package auth.utils; import java.io.BufferedInputStream; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.File; import java.io.FileInputStream; import java.io.FileOutputStream; import java.io.IOException; import java.io.InputStream; import java.io.StringWriter; import java.security.KeyFactory; import java.security.cert.CertificateFactory; import java.security.cert.X509Certificate; import java.security.interfaces.RSAPrivateKey; import java.security.spec.PKCS8EncodedKeySpec; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; import org.joda.time.DateTime; import org.opensaml.Configuration; import org.opensaml.DefaultBootstrap; import org.opensaml.common.SAMLObjectBuilder; import org.opensaml.common.SAMLVersion; import org.opensaml.common.xml.SAMLConstants; import org.opensaml.saml2.core.Assertion; import org.opensaml.saml2.core.Attribute; import org.opensaml.saml2.core.AttributeQuery; import org.opensaml.saml2.core.AttributeStatement; import org.opensaml.saml2.core.AuthnContext; import org.opensaml.saml2.core.AuthnContextClassRef; import org.opensaml.saml2.core.AuthnContextComparisonTypeEnumeration; import org.opensaml.saml2.core.AuthnRequest; import org.opensaml.saml2.core.Conditions; import org.opensaml.saml2.core.EncryptedAssertion; import org.opensaml.saml2.core.Issuer; import org.opensaml.saml2.core.NameID; import org.opensaml.saml2.core.NameIDPolicy; import org.opensaml.saml2.core.RequestedAuthnContext; import org.opensaml.saml2.core.Response; import org.opensaml.saml2.core.StatusCode; import org.opensaml.saml2.core.Subject; import org.opensaml.saml2.encryption.Decrypter; import org.opensaml.ws.soap.common.SOAPObjectBuilder; import org.opensaml.ws.soap.soap11.Body; import org.opensaml.ws.soap.soap11.Envelope; import org.opensaml.xml.ConfigurationException; import org.opensaml.xml.XMLObject; import org.opensaml.xml.XMLObjectBuilder; import org.opensaml.xml.XMLObjectBuilderFactory; import org.opensaml.xml.encryption.InlineEncryptedKeyResolver; import org.opensaml.xml.io.Marshaller; import org.opensaml.xml.io.MarshallerFactory; import org.opensaml.xml.io.MarshallingException; import org.opensaml.xml.io.Unmarshaller; import org.opensaml.xml.io.UnmarshallerFactory; import org.opensaml.xml.parse.BasicParserPool; import org.opensaml.xml.parse.ParserPool; import org.opensaml.xml.security.SecurityConfiguration; import org.opensaml.xml.security.SecurityHelper; import org.opensaml.xml.security.credential.Credential; import org.opensaml.xml.security.keyinfo.StaticKeyInfoCredentialResolver; import org.opensaml.xml.security.x509.BasicX509Credential; import org.opensaml.xml.signature.SignableXMLObject; import org.opensaml.xml.signature.Signature; import org.opensaml.xml.signature.Signer; import org.opensaml.xml.util.Base64; import org.opensaml.xml.util.XMLHelper; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.w3c.dom.Document; import org.w3c.dom.Element; public class SAMLUtils { private static Logger logger = LoggerFactory.getLogger(SAMLUtils.class); private static SAMLUtils instance; private ParserPool parserPool; private XMLObjectBuilderFactory builderFactory; private MarshallerFactory marshallerFactory; private UnmarshallerFactory unmarshallerFactory; private String derFile; private String pemFile; private String samlIssuerUrl; private String idpSoapUrl; private String samlUsername; private boolean debug; public static final String ATTR_DER_FILE = "derFile"; public static final String ATTR_PEM_FILE = "pemFile"; public static final String ATTR_ISSUER = "samlIssuerUrl"; public static final String ATTR_IDP_SOAP_URL = "idpSoapUrl"; public static final String ATTR_USERNAME = "samlUsername"; private SAMLUtils() { } public static final SAMLUtils getInstance() { if (instance == null) { synchronized (SAMLUtils.class) { instance = new SAMLUtils(); try { instance.init(); } catch (ConfigurationException e) { logger.error("Can't initialize openSAML!", e); } } } return instance; } private void init() throws ConfigurationException { DefaultBootstrap.bootstrap(); parserPool = new BasicParserPool(); builderFactory = Configuration.getBuilderFactory(); marshallerFactory = Configuration.getMarshallerFactory(); unmarshallerFactory = Configuration.getUnmarshallerFactory(); } public void setDERFileNm(String der) { this.derFile = der; } public void setPEMFileNm(String pem) { this.pemFile = pem; } public void setSamlIssuerUrl(String issuer) { this.samlIssuerUrl = issuer; } public void setIdpSoapUrl(String idpSoap) { this.idpSoapUrl = idpSoap; } public void setSamlUsername(String usernm) { this.samlUsername = usernm; } /** * Build the AuthnRequest message. * * @param forceAuthn * @return */ public String buildAuthnRequest(String samlIssuerUrl, String assertionConsumerServiceUrl, boolean forceAuthn) { //logger.debug("issuer: " + samlIssuerUrl + "; consumer: " + assertionConsumerServiceUrl); XMLObjectBuilder issuerBuilder = builderFactory.getBuilder(Issuer.DEFAULT_ELEMENT_NAME); Issuer issuer = (Issuer) issuerBuilder.buildObject(Issuer.DEFAULT_ELEMENT_NAME); issuer.setValue(samlIssuerUrl); // Create NameIDPolicy XMLObjectBuilder nameIdPolicyBuilder = builderFactory.getBuilder(NameIDPolicy.DEFAULT_ELEMENT_NAME); NameIDPolicy nameIdPolicy = (NameIDPolicy) nameIdPolicyBuilder.buildObject(NameIDPolicy.DEFAULT_ELEMENT_NAME); nameIdPolicy.setFormat(NameID.TRANSIENT); nameIdPolicy.setSPNameQualifier(samlIssuerUrl); nameIdPolicy.setAllowCreate(true); // Create AuthnContextClassRef XMLObjectBuilder authnContextClassRefBuilder = builderFactory .getBuilder(AuthnContextClassRef.DEFAULT_ELEMENT_NAME); AuthnContextClassRef authnContextClassRef = (AuthnContextClassRef) authnContextClassRefBuilder .buildObject(AuthnContextClassRef.DEFAULT_ELEMENT_NAME); authnContextClassRef.setAuthnContextClassRef(AuthnContext.PPT_AUTHN_CTX); // Create RequestedAuthnContext XMLObjectBuilder requestedAuthnContextBuilder = builderFactory .getBuilder(RequestedAuthnContext.DEFAULT_ELEMENT_NAME); RequestedAuthnContext requestedAuthnContext = (RequestedAuthnContext) requestedAuthnContextBuilder .buildObject(RequestedAuthnContext.DEFAULT_ELEMENT_NAME); requestedAuthnContext.setComparison(AuthnContextComparisonTypeEnumeration.EXACT); requestedAuthnContext.getAuthnContextClassRefs().add(authnContextClassRef); XMLObjectBuilder authnRequestBuilder = builderFactory.getBuilder(AuthnRequest.DEFAULT_ELEMENT_NAME); AuthnRequest authRequest = (AuthnRequest) authnRequestBuilder.buildObject(AuthnRequest.DEFAULT_ELEMENT_NAME); authRequest.setForceAuthn(forceAuthn); authRequest.setIsPassive(false); authRequest.setIssueInstant(new DateTime()); authRequest.setProtocolBinding(SAMLConstants.SAML2_POST_BINDING_URI); authRequest.setAssertionConsumerServiceURL(assertionConsumerServiceUrl); authRequest.setIssuer(issuer); authRequest.setNameIDPolicy(nameIdPolicy); authRequest.setRequestedAuthnContext(requestedAuthnContext); authRequest.setID(java.util.UUID.randomUUID().toString()); authRequest.setVersion(SAMLVersion.VERSION_20); // Now we must build our representation to put into the html form to be submitted to the idp Marshaller marshaller = marshallerFactory.getMarshaller(AuthnRequest.DEFAULT_ELEMENT_NAME); org.w3c.dom.Element authDOM = null; try { authDOM = marshaller.marshall(authRequest); } catch (MarshallingException e) { logger.error("Failed marshalling the xml", e); return null; } StringWriter rspWrt = new StringWriter(); XMLHelper.writeNode(authDOM, rspWrt); String messageXML = rspWrt.toString(); saveToFile("AuthnRequest.xml", messageXML); return Base64.encodeBytes(messageXML.getBytes()); } public List<Assertion> decodeAssertions(Response resp) { ArrayList<Assertion> assertions = new ArrayList<Assertion>(); try { logger.trace("AuthnResponse Assertions=" + resp.getAssertions().size() + "; EncryptedAssertions=" + resp.getEncryptedAssertions().size()); for (Assertion assertion : resp.getAssertions()) { assertions.add(assertion); } int i = 0; for (EncryptedAssertion encryptedAssertion : resp.getEncryptedAssertions()) { Assertion assertion = decodeAssertion(encryptedAssertion); assertions.add(assertion); saveToFile("DecodedAuthnAssertion" + i++ + ".xml", assertion); } } catch (Exception e) { logger.error("failed decoding SAMLResponse", e); } return assertions; } public Response decodeSAMLResponse(String samlResponse) { try { byte[] decodedBytes = Base64.decode(samlResponse); ByteArrayInputStream bytesIn = new ByteArrayInputStream(decodedBytes); // InflaterInputStream inflater = new InflaterInputStream(bytesIn, new Inflater()); saveToFile("AuthnResponse.xml", decodedBytes); Document messageDoc = parserPool.parse(bytesIn); Element messageElem = messageDoc.getDocumentElement(); // logger.info("DOM was:\n{}", XMLHelper.nodeToString(messageElem)); Unmarshaller unmarshaller = unmarshallerFactory.getUnmarshaller(messageElem); if (unmarshaller == null) { logger.trace("Unable to unmarshall message, no unmarshaller registered for message element " + XMLHelper.getNodeQName(messageElem)); } Response resp = (Response) unmarshaller.unmarshall(messageElem); logger.trace("AuthnResponse StatusCode:" + resp.getStatus().getStatusCode().getValue()); return resp; } catch (Exception e) { logger.error("failed decoding SAMLResponse", e); } return null; } private Assertion decodeAssertion(EncryptedAssertion encryptedAssertion) { try { Credential decryptionCredential = getCredential(); StaticKeyInfoCredentialResolver skicr = new StaticKeyInfoCredentialResolver(decryptionCredential); Decrypter samlDecrypter = new Decrypter(null, skicr, new InlineEncryptedKeyResolver()); return samlDecrypter.decrypt(encryptedAssertion); } catch (Exception e) { logger.error("failed decrypting assertion!", e); } return null; } private Credential getCredential() { BasicX509Credential credential = null; try { // read private key File privateKeyFile = new File(derFile); FileInputStream inputStreamPrivateKey = new FileInputStream(privateKeyFile); byte[] encodedPrivateKey = new byte[(int) privateKeyFile.length()]; inputStreamPrivateKey.read(encodedPrivateKey); inputStreamPrivateKey.close(); PKCS8EncodedKeySpec privateKeySpec = new PKCS8EncodedKeySpec(encodedPrivateKey); RSAPrivateKey privateKey = (RSAPrivateKey) KeyFactory.getInstance("RSA").generatePrivate( privateKeySpec); // read the certificate InputStream inStream = new FileInputStream(pemFile); CertificateFactory cf = CertificateFactory.getInstance("X.509"); X509Certificate cert = (X509Certificate) cf.generateCertificate(inStream); // create credential credential = new BasicX509Credential(); credential.setEntityCertificate(cert); credential.setPrivateKey(privateKey); } catch (Exception e) { logger.error("failed getting credential!", e); } return credential; } private static String readInputStreamAsString(InputStream in) throws IOException { BufferedInputStream bis = new BufferedInputStream(in); ByteArrayOutputStream buf = new ByteArrayOutputStream(); int result = bis.read(); while (result != -1) { byte b = (byte) result; buf.write(b); result = bis.read(); } return buf.toString(); } public static String getXMLAsString(XMLObject obj) { Marshaller marshaller = Configuration.getMarshallerFactory().getMarshaller(obj); StringWriter rspWrt = new StringWriter(); try { org.w3c.dom.Element domEl = marshaller.marshall(obj); XMLHelper.writeNode(domEl, rspWrt); } catch (MarshallingException e) { logger.error("Failed marshalling the XMLObject!", e); } return rspWrt.toString(); } private AttributeQuery buildAttributeQuery(String name) { SAMLObjectBuilder<Issuer> issuerBuilder = (SAMLObjectBuilder<Issuer>) builderFactory .getBuilder(Issuer.DEFAULT_ELEMENT_NAME); Issuer issuer = issuerBuilder.buildObject(); issuer.setFormat(NameID.ENTITY); issuer.setValue(samlIssuerUrl); SAMLObjectBuilder<NameID> nameIdBuilder = (SAMLObjectBuilder<NameID>) builderFactory .getBuilder(NameID.DEFAULT_ELEMENT_NAME); NameID nameId = nameIdBuilder.buildObject(); nameId.setValue(name); SAMLObjectBuilder<Subject> subjectBuilder = (SAMLObjectBuilder<Subject>) builderFactory .getBuilder(Subject.DEFAULT_ELEMENT_NAME); Subject subject = subjectBuilder.buildObject(); subject.setNameID(nameId); SAMLObjectBuilder<AttributeQuery> attributeQueryBuilder = (SAMLObjectBuilder<AttributeQuery>) builderFactory .getBuilder(AttributeQuery.DEFAULT_ELEMENT_NAME); AttributeQuery query = attributeQueryBuilder.buildObject(); query.setID(java.util.UUID.randomUUID().toString()); query.setIssueInstant(new DateTime()); query.setIssuer(issuer); query.setSubject(subject); query.setVersion(SAMLVersion.VERSION_20); return query; } private String getSOAPMessage(AttributeQuery query) throws MarshallingException { SOAPObjectBuilder<Body> bodyBuilder = (SOAPObjectBuilder<Body>) builderFactory .getBuilder(Body.DEFAULT_ELEMENT_NAME); Body body = bodyBuilder.buildObject(); body.getUnknownXMLObjects().add(query); SOAPObjectBuilder<Envelope> envelopeBuilder = (SOAPObjectBuilder<Envelope>) builderFactory .getBuilder(Envelope.DEFAULT_ELEMENT_NAME); Envelope envelope = envelopeBuilder.buildObject(); envelope.setBody(body); Marshaller marshaller = marshallerFactory.getMarshaller(envelope); Element envelopeElem = marshaller.marshall(envelope); StringWriter writer = new StringWriter(); XMLHelper.writeNode(envelopeElem, writer); return writer.toString(); } private Response attributeQuery(String nameId) { try { AttributeQuery query = buildAttributeQuery(nameId); signRequest(query); String soapRequest = getSOAPMessage(query); saveToFile("AttrQueryRequest.xml", soapRequest); SendSoapMsg sender = new SendSoapMsg(idpSoapUrl); String soapResponse = sender.sendMsg(soapRequest); saveToFile("AttrQueryResponse.xml", soapResponse); ByteArrayInputStream bytes = new ByteArrayInputStream(soapResponse.getBytes()); Document messageDoc = parserPool.parse(bytes); Element messageElem = messageDoc.getDocumentElement(); Unmarshaller unmarshaller = unmarshallerFactory.getUnmarshaller(messageElem); Envelope envelope = (Envelope) unmarshaller.unmarshall(messageElem); Response resp = (Response) envelope.getBody().getOrderedChildren().get(0); return resp; } catch (Exception e) { logger.error("Failed retrieving attributes!", e); } return null; } private List<Attribute> getAttributesFromAssertions(List<Assertion> assertions) { ArrayList<Attribute> attrs = new ArrayList<Attribute>(); for (Assertion assertion : assertions) { attrs.addAll(getAttributesFromAssertion(assertion)); } return attrs; } private List<Attribute> getAttributesFromAssertion(Assertion assertion) { ArrayList<Attribute> attrs = new ArrayList<Attribute>(); for (AttributeStatement stmt : assertion.getAttributeStatements()) { attrs.addAll(stmt.getAttributes()); } return attrs; } private void saveToFile(String fileNm, String fileContent) { saveToFile(fileNm, fileContent.getBytes()); } private void saveToFile(String fileNm, byte[] fileContent) { if (!debug) return; try { String filePath = new File(".").getAbsolutePath() + "\\public\\xmlSample\\"; File f = new File(filePath + fileNm); if (f.exists()) { f.delete(); } f.createNewFile(); FileOutputStream fi = new FileOutputStream(f); fi.write(fileContent); fi.flush(); fi.close(); } catch (Exception e) { logger.error("Can't save to file", e); } } private void saveToFile(String fileNm, XMLObject obj) { Marshaller marshaller = marshallerFactory.getMarshaller(obj); org.w3c.dom.Element authDOM = null; try { authDOM = marshaller.marshall(obj); } catch (MarshallingException e) { logger.error("Failed marshalling the xml", e); return; } StringWriter rspWrt = new StringWriter(); XMLHelper.writeNode(authDOM, rspWrt); String messageXML = rspWrt.toString(); saveToFile(fileNm, messageXML.getBytes()); } private void signRequest(SignableXMLObject obj) { Credential credential = getCredential(); Signature signature = (Signature) Configuration.getBuilderFactory() .getBuilder(Signature.DEFAULT_ELEMENT_NAME).buildObject(Signature.DEFAULT_ELEMENT_NAME); signature.setSigningCredential(credential); SecurityConfiguration secConfig = Configuration.getGlobalSecurityConfiguration(); try { SecurityHelper.prepareSignatureParams(signature, credential, secConfig, null); obj.setSignature(signature); Configuration.getMarshallerFactory().getMarshaller(obj).marshall(obj); Signer.signObject(signature); } catch (Exception e) { logger.error("Can't prepare signature", e); } } private void logAttributes(List<Attribute> attrs) { if (!debug) return; for (Attribute attr : attrs) { String s = "Attribute name=" + attr.getName() + "; friendlyName=" + attr.getFriendlyName() + "; nameFormat=" + attr.getNameFormat() + "; values=" + attr.getAttributeValues().size() + " ["; for (XMLObject val : attr.getAttributeValues()) { s += "{qname:" + val.getElementQName() + ", qVal:" + val.getDOM().getNodeValue() + "}"; } s += "]"; logger.debug(s); } } public boolean processConditions(Conditions conditions) { // TODO return true; } public Map<String, String> getAttributeValue(Response authnResp, String nameId) { HashMap<String, String> attributes = new HashMap<String, String>(); ArrayList<Attribute> attrs = new ArrayList<Attribute>(); attrs.addAll(getAttributesFromAssertions(authnResp.getAssertions())); int i = 0; for (EncryptedAssertion encryptedAssertion : authnResp.getEncryptedAssertions()) { Assertion assertion = decodeAssertion(encryptedAssertion); saveToFile("DecodedAttrQueryAssertion" + i++ + ".xml", assertion); attrs.addAll(getAttributesFromAssertion(assertion)); } logger.trace("Found " + attrs.size() + " attributes in the AuthnResponse"); for (Attribute attr : attrs) { String nm = attr.getName() == null ? attr.getFriendlyName() : attr.getName(); attributes.put(nm, attr.getAttributeValues().get(0).getDOM().getTextContent()); } logAttributes(attrs); if (attributes.containsKey(samlUsername)) { return attributes; } logger.trace("Attribute '" + samlIssuerUrl + "' not found in AuthnResponse, make an AttributeQuery ..."); Response resp = attributeQuery(nameId); String statusCode = resp.getStatus().getStatusCode().getValue(); logger.trace("AttrQuery StatusCode:" + statusCode); if (!statusCode.equals(StatusCode.SUCCESS_URI)) { String statusMsg = resp.getStatus().getStatusMessage().getMessage(); logger.info("AttrQuery FAILED! " + statusMsg); } else { logger.trace("AttrQuery Assertions=" + resp.getAssertions().size() + "; EncryptedAssertions=" + resp.getEncryptedAssertions().size()); attrs.addAll(getAttributesFromAssertions(resp.getAssertions())); for (EncryptedAssertion encryptedAssertion : resp.getEncryptedAssertions()) { Assertion assertion = decodeAssertion(encryptedAssertion); attrs.addAll(getAttributesFromAssertion(assertion)); } logger.trace("Received " + attrs.size() + " attributes from AttributeQuery response"); } for (Attribute attr : attrs) { String nm = attr.getName() == null ? attr.getFriendlyName() : attr.getName(); attributes.put(nm, attr.getAttributeValues().get(0).getDOM().getTextContent()); } logAttributes(attrs); return attributes; } /** * @param debug */ public void setDebug(boolean debug) { this.debug = debug; } }