package de.groothues.mysaml.samples.ws;
import java.util.Collections;
import java.util.Iterator;
import java.util.Set;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.xml.namespace.QName;
import javax.xml.soap.Node;
import javax.xml.soap.SOAPBody;
import javax.xml.soap.SOAPException;
import javax.xml.soap.SOAPFault;
import javax.xml.soap.SOAPHeader;
import javax.xml.soap.SOAPMessage;
import javax.xml.ws.WebServiceException;
import javax.xml.ws.handler.MessageContext;
import javax.xml.ws.handler.soap.SOAPHandler;
import javax.xml.ws.handler.soap.SOAPMessageContext;
import javax.xml.ws.soap.SOAPFaultException;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.w3c.dom.NodeList;
import de.groothues.mysaml.SamlContext;
import de.groothues.mysaml.SamlContextFactory;
import de.groothues.mysaml.assertion.AssertionType;
import de.groothues.mysaml.impl.DomHelper;
import de.groothues.mysaml.validator.ValidationResult;
/**
 * JAX-WS SOAP handler that authenticates inbound messages by validating a SAML 2.0 assertion
 * carried inside a {@code Security} SOAP header.
 *
 * <p>Outbound messages pass through untouched. For inbound messages, the assertion is located,
 * copied into its own DOM document and unmarshalled. It is then checked for signature, age,
 * validity period and audience, in that order. The first failing check replaces the message
 * content with a SOAP fault and aborts processing by throwing a {@link SOAPFaultException}.
 */
public class SamlAuthenticationHandler implements SOAPHandler<SOAPMessageContext> {

    private static final Logger LOGGER =
            Logger.getLogger(SamlAuthenticationHandler.class.getName());

    /** Namespace URI of SAML 2.0 {@code Assertion} elements. */
    private static final String SAML2_ASSERTION_NS = "urn:oasis:names:tc:SAML:2.0:assertion";

    /** Local name of the SAML 2.0 assertion element. */
    private static final String ASSERTION_LOCAL_NAME = "Assertion";

    /** Local name of the SOAP header element searched for the assertion (any namespace). */
    private static final String SECURITY_HEADER_LOCAL_NAME = "Security";

    /**
     * Validates the SAML assertion of inbound messages; outbound messages are ignored.
     *
     * @param context the message context of the current exchange
     * @return always {@code true}; a failed validation is signalled by an exception instead
     * @throws SOAPFaultException if the assertion is missing or fails one of the checks
     * @throws WebServiceException if the SOAP message cannot be read or the fault cannot be
     *         built; the message is rejected rather than let through unauthenticated
     */
    @Override
    public boolean handleMessage(SOAPMessageContext context) {
        if (!isInbound(context)) {
            return true;
        }
        try {
            validateAssertion(context.getMessage());
        } catch (SOAPException e) {
            // Fail closed: a message whose assertion could not be checked must not be processed.
            LOGGER.log(Level.WARNING, "Could not validate SAML assertion of inbound message", e);
            throw new WebServiceException("Could not validate SAML assertion", e);
        }
        LOGGER.fine("Message handled successfully.");
        return true;
    }

    /**
     * Locates the SAML assertion in the message header and runs all validators on it.
     *
     * @param soapMessage the inbound message; it is rewritten into a fault if validation fails
     * @throws SOAPFaultException if no assertion is present or any validator reports failure
     * @throws SOAPException if the message cannot be read or the fault cannot be built
     */
    private void validateAssertion(SOAPMessage soapMessage) throws SOAPException {
        Element assertionElement = getAssertion(soapMessage.getSOAPHeader());
        if (assertionElement == null) {
            throw createSOAPFault(soapMessage, "No SAML Assertion found in SOAP Header");
        }
        SamlContext samlContext = SamlContextFactory.createSamlContext();
        Document assertionDoc = toDocument(assertionElement);
        AssertionType assertion = samlContext.getAssertionBuilder().unmarshal(assertionDoc);

        // Verify integrity first: the remaining checks evaluate claims made by the assertion,
        // which are only meaningful once the signature is known to be valid.
        requireValid(soapMessage, samlContext.getSignatureValidator().validate(assertionDoc));
        requireValid(soapMessage, samlContext.getAssertionAgeValidator().validate(assertion));
        requireValid(soapMessage,
                samlContext.getAssertionValidityPeriodValidator().validate(assertion));
        requireValid(soapMessage, samlContext.getAssertionAudienceValidator().validate(assertion));
    }

    /**
     * Turns a failed validation result into a SOAP fault.
     *
     * @param soapMessage the message to rewrite into a fault on failure
     * @param result the validator outcome; its result message becomes the fault string
     * @throws SOAPFaultException if {@code result} is not valid
     * @throws SOAPException if the fault cannot be built
     */
    private void requireValid(SOAPMessage soapMessage, ValidationResult result)
            throws SOAPException {
        if (!result.isValid()) {
            throw createSOAPFault(soapMessage, result.getResultMessage());
        }
    }

    /**
     * Passes faults on unchanged.
     *
     * @param context the message context of the current exchange
     * @return always {@code true}
     */
    @Override
    public boolean handleFault(SOAPMessageContext context) {
        return true;
    }

    /**
     * No-op: this handler holds no per-exchange resources.
     *
     * @param context the message context of the finished exchange
     */
    @Override
    public void close(MessageContext context) {
    }

    /**
     * Returns the header blocks this handler declares it processes.
     *
     * @return an empty set; no headers are declared
     */
    @Override
    public Set<QName> getHeaders() {
        return Collections.emptySet();
    }

    /**
     * Tells whether the message in {@code context} is inbound.
     *
     * @param context the message context of the current exchange
     * @return {@code false} only if the outbound property is explicitly {@code true}
     */
    private boolean isInbound(SOAPMessageContext context) {
        Boolean isOutbound = (Boolean) context.get(MessageContext.MESSAGE_OUTBOUND_PROPERTY);
        // A missing property is treated as inbound so that authentication is never skipped
        // (and unboxing null no longer throws a NullPointerException).
        return !Boolean.TRUE.equals(isOutbound);
    }

    /**
     * Clears the message and replaces its body with a SOAP fault.
     *
     * @param message the message to rewrite
     * @param faultString the human-readable fault reason
     * @return the exception wrapping the new fault, for the caller to throw
     * @throws SOAPException if the message cannot be modified
     */
    private SOAPFaultException createSOAPFault(SOAPMessage message, String faultString)
            throws SOAPException {
        // Remove the request headers and body so the message carries only the fault.
        SOAPHeader header = message.getSOAPHeader();
        if (header != null) {
            header.removeContents();
        }
        SOAPBody body = message.getSOAPPart().getEnvelope().getBody();
        body.removeContents();
        SOAPFault fault = body.addFault();
        fault.setFaultString(faultString);
        return new SOAPFaultException(fault);
    }

    /**
     * Finds the first SAML 2.0 assertion inside a {@code Security} header element.
     *
     * <p>Header elements are matched by local name only. Every matching header is searched
     * in order, and the first one containing an assertion wins.
     *
     * @param header the SOAP header, may be {@code null} if the message has none
     * @return the first assertion element found (at any depth below the {@code Security}
     *         element), or {@code null} if there is none
     */
    private Element getAssertion(SOAPHeader header) {
        if (header == null) {
            return null;
        }
        Iterator<?> it = header.examineAllHeaderElements();
        while (it.hasNext()) {
            Node node = (Node) it.next();
            if (node.getNodeType() != Node.ELEMENT_NODE) {
                continue;
            }
            Element element = (Element) node;
            // Null-safe: getLocalName() returns null for nodes created with DOM Level 1 APIs.
            if (!SECURITY_HEADER_LOCAL_NAME.equals(element.getLocalName())) {
                continue;
            }
            NodeList nodes = element.getElementsByTagNameNS(SAML2_ASSERTION_NS, ASSERTION_LOCAL_NAME);
            if (nodes.getLength() > 0) {
                return (Element) nodes.item(0);
            }
        }
        return null;
    }

    /**
     * Deep-copies {@code assertion} into a new standalone DOM document.
     *
     * @param assertion the assertion element to copy
     * @return a new document whose root is a copy of {@code assertion}
     */
    private Document toDocument(Element assertion) {
        Document doc = DomHelper.createNewDocument();
        doc.appendChild(doc.importNode(assertion, true));
        return doc;
    }
}