package org.talend.esb.mep.requestcallback.impl; import java.io.File; import java.net.MalformedURLException; import java.net.URL; import java.security.cert.X509Certificate; import java.util.List; import java.util.Map; import javax.wsdl.Definition; import javax.xml.namespace.QName; import org.apache.cxf.binding.soap.SoapMessage; import org.apache.cxf.headers.Header; import org.apache.cxf.helpers.CastUtils; import org.apache.cxf.helpers.DOMUtils; import org.apache.cxf.interceptor.Fault; import org.apache.cxf.message.Exchange; import org.apache.cxf.message.Message; import org.apache.cxf.phase.AbstractPhaseInterceptor; import org.apache.cxf.phase.Phase; import org.apache.cxf.rt.security.utils.SecurityUtils; import org.apache.cxf.service.model.BindingInfo; import org.apache.cxf.ws.addressing.AddressingProperties; import org.apache.cxf.ws.addressing.JAXWSAConstants; import org.apache.cxf.ws.addressing.MAPAggregator; import org.apache.cxf.ws.security.SecurityConstants; import org.apache.cxf.wsdl.WSDLManager; import org.apache.wss4j.dom.WSConstants; import org.apache.wss4j.dom.engine.WSSecurityEngineResult; import org.apache.wss4j.dom.handler.WSHandlerConstants; import org.apache.wss4j.dom.handler.WSHandlerResult; import org.talend.esb.mep.requestcallback.feature.CallContext; import org.talend.esb.mep.requestcallback.feature.CallbackInfo; import org.talend.esb.mep.requestcallback.feature.RequestCallbackFeature; import org.talend.esb.sam.agent.flowidprocessor.FlowIdProtocolHeaderCodec; import org.talend.esb.sam.agent.flowidprocessor.FlowIdSoapCodec; import org.talend.esb.sam.agent.message.FlowIdHelper; import org.w3c.dom.Element; /** * Inbound CXF interceptor for the request-callback message exchange pattern. 
*/ public class RequestCallbackInInterceptor extends AbstractPhaseInterceptor<SoapMessage> { public RequestCallbackInInterceptor() { super(Phase.PRE_LOGICAL); addAfter(MAPAggregator.class.getName()); } /** * {@inheritDoc} */ @Override public void handleMessage(SoapMessage message) throws Fault { final Header callHeader = message.getHeader( RequestCallbackFeature.CALL_ID_HEADER_NAME); if (callHeader == null) { return; } final Exchange e = message.getExchange(); if (!e.isOneWay()) { e.setOneWay(true); } final Header callbackHeader = message.getHeader( RequestCallbackFeature.CALLBACK_ID_HEADER_NAME); if (callbackHeader == null) { doHandleRequestSoapMessage(message, callHeader); } else { doHandleCallbackSoapMessage(message, callHeader, callbackHeader); } } private void doHandleRequestSoapMessage( SoapMessage message, Header callHeader) throws Fault { CallContext ctx = setupCallContext(message, callHeader, null); // In case of using requestor certificate for response encryption, store certificate into callContext storeRequestorCertificate(message, ctx); } private void doHandleCallbackSoapMessage( SoapMessage message, Header callHeader, Header callbackHeader) throws Fault { setupCallContext(message, callHeader, callbackHeader); setupFlowId(message); } private CallContext setupCallContext( SoapMessage message, Header callHeader, Header callbackHeader) throws Fault { final AddressingProperties maps = getAddressingProperties(message); if (maps == null) { throw new IllegalStateException( "Request-Callback enabled but no WS-Addressing headers set. 
"); } CallContext ctx = new CallContext(); message.put(RequestCallbackFeature.CALLCONTEXT_PROPERTY_NAME, ctx); final QName operationName = QName.valueOf(maps.getAction().getValue()); if (!isGenericOperation(operationName)) { ctx.setOperationName(operationName); } ctx.setCallId(valueOf(callHeader)); if (callbackHeader != null) { ctx.setCallbackId(valueOf(callbackHeader)); } ctx.setRequestId(maps.getMessageID().getValue()); ctx.setReplyToAddress(maps.getReplyTo().getAddress().getValue()); ctx.setCorrelationId(getCorrelationId(message)); // Try to get SAM flowId in request message // to store it in CallContext for subsequent use // in callback message if (callbackHeader == null) { setupFlowId(message); } fillCallContext(ctx, message); return ctx; } private static String getCorrelationId(SoapMessage message) { Header h = message.getHeader(RequestCallbackFeature.CORRELATION_ID_HEADER_NAME); if(h!=null){ return valueOf(h); } return null; } private static AddressingProperties getAddressingProperties(SoapMessage message) { AddressingProperties maps = (AddressingProperties) message.get( JAXWSAConstants.CLIENT_ADDRESSING_PROPERTIES); if (maps == null) { maps = (AddressingProperties) message.get( JAXWSAConstants.ADDRESSING_PROPERTIES_INBOUND); } return maps; } private static String valueOf(Header header) { final Object headerObject = header.getObject(); if (headerObject == null) { return null; } if (headerObject instanceof String) { return (String) headerObject; } if (headerObject instanceof Element) { return DOMUtils.getContent((Element) headerObject); } return null; } private static boolean isGenericOperation(QName operationName) { final String name = operationName.getLocalPart(); return name.startsWith("http://cxf.apache.org/jaxws") && name.endsWith("InvokeOneWayRequest"); } public static void fillCallContext(CallContext callContext, SoapMessage message) { if (callContext.getOperationName() == null) { callContext.setOperationName((QName) 
message.get(Message.WSDL_OPERATION)); } callContext.setPortTypeName((QName) message.get(Message.WSDL_INTERFACE)); callContext.setServiceName((QName) message.get(Message.WSDL_SERVICE)); final BindingInfo bi = message.getExchange().getBinding().getBindingInfo(); callContext.setBindingId(bi == null ? "http://schemas.xmlsoap.org/wsdl/soap/" : bi.getBindingId()); URL wsdlLocation = resolveCallbackWsdlLocation(callContext.getServiceName(), message); if (wsdlLocation != null) { callContext.setWsdlLocation(wsdlLocation); } String flowId = FlowIdHelper.getFlowId(message); if (flowId != null && !flowId.isEmpty()) { callContext.setFlowId(flowId); } } private static URL resolveCallbackWsdlLocation(QName callbackService, SoapMessage message) { final WSDLManager wsdlManager = message.getExchange().getBus().getExtension(WSDLManager.class); for (Map.Entry<Object, Definition> entry : wsdlManager.getDefinitions().entrySet()) { if (entry.getValue().getService(callbackService) != null) { final Object key = entry.getKey(); if (key instanceof URL) { return asCallbackWsdlURL((URL) entry.getKey()); } if (key instanceof String) { final String loc = (String) key; if (loc.startsWith("file:") || loc.indexOf("://") > 0) { try { return asCallbackWsdlURL(new URL(loc)); } catch (MalformedURLException e) { throw new IllegalStateException("Corrupted WSDL location URL: ", e); } } final File wsdlFile = new File(loc); if (wsdlFile.exists()) { return toCallbackWsdlURL(wsdlFile); } // classpath resolution will only work where the loader // of request-callback classes is in the same classloader // as the invoking application. Otherwise, the WSDL location // must be provided as URL or absolute file path. 
URL classpathWsdlURL = CallContext.class.getClassLoader().getResource(loc); if (classpathWsdlURL != null) { return asCallbackWsdlURL(classpathWsdlURL); } return null; } } } return null; } private static URL asCallbackWsdlURL(URL wsdlURL) { if (wsdlURL == null) { return null; } final CallbackInfo cbInfo = CallContext.createCallbackInfo(wsdlURL); if (cbInfo.getCallbackServiceName() == null) { // old-style callback definition without callback service. return null; } final String callbackWsdlLocation = cbInfo.getSpecificCallbackSenderWsdlLocation(null); try { return callbackWsdlLocation == null ? wsdlURL : new URL(callbackWsdlLocation); } catch (MalformedURLException e) { throw new IllegalStateException("Unexpected URL creation problem: ", e); } } private static URL toCallbackWsdlURL(File wsdlFile) { if (!(wsdlFile.isFile() && wsdlFile.canRead())) { throw new IllegalStateException("File " + wsdlFile.getName() + " is not a readable file. "); } try { final URL wsdlURL = wsdlFile.toURI().toURL(); final CallbackInfo cbInfo = CallContext.createCallbackInfo(wsdlURL); if (cbInfo.getCallbackServiceName() == null) { // old-style callback definition without callback service. return null; } return wsdlURL; } catch (MalformedURLException e) { throw new IllegalArgumentException("Cannot create URL for WSDL file location. 
", e); } } /** * This functions reads SAM flowId and sets it * as message property for subsequent store in CallContext * @param message */ private static void setupFlowId(SoapMessage message) { String flowId = FlowIdHelper.getFlowId(message); if (flowId == null) { flowId = FlowIdProtocolHeaderCodec.readFlowId(message); } if (flowId == null) { flowId = FlowIdSoapCodec.readFlowId(message); } if (flowId == null) { Exchange ex = message.getExchange(); if (null!=ex){ Message reqMsg = ex.getOutMessage(); if ( null != reqMsg) { flowId = FlowIdHelper.getFlowId(reqMsg); } } } if (flowId != null && !flowId.isEmpty()) { FlowIdHelper.setFlowId(message, flowId); } } private static void storeRequestorCertificate(Message message, CallContext callContext) { String encrUser = (String) SecurityUtils.getSecurityPropertyValue(SecurityConstants.ENCRYPT_USERNAME, message); if (WSHandlerConstants.USE_REQ_SIG_CERT.equals(encrUser)) { X509Certificate reqSignCert = getReqSigCert(message); callContext.setRequestorSignatureCertificate(reqSignCert); } } // TODO: Currently this method is private in CXF AbstractBindingBuilder. // Refactor the method into public CXF utility and reuse it from CXF instead copy&paste private static X509Certificate getReqSigCert(Message message) { List<WSHandlerResult> results = CastUtils.cast((List<?>) message.getExchange().getInMessage().get(WSHandlerConstants.RECV_RESULTS)); if (results == null) { return null; } /* * Scan the results for a matching actor. Use results only if the * receiving Actor and the sending Actor match. */ for (WSHandlerResult rResult : results) { List<WSSecurityEngineResult> wsSecEngineResults = rResult .getResults(); /* * Scan the results for the first Signature action. Use the * certificate of this Signature to set the certificate for the * encryption action :-). 
*/ for (WSSecurityEngineResult wser : wsSecEngineResults) { Integer actInt = (Integer) wser .get(WSSecurityEngineResult.TAG_ACTION); if (actInt.intValue() == WSConstants.SIGN) { return (X509Certificate) wser .get(WSSecurityEngineResult.TAG_X509_CERTIFICATE); } } } return null; } }