/* **********************************************************
* Copyright 2010 VMware, Inc. All rights reserved. -- VMware Confidential
* **********************************************************/
package com.emc.storageos.vasa.util;
import java.io.ByteArrayInputStream;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.lang.management.ManagementFactory;
import java.security.GeneralSecurityException;
import java.security.KeyStore;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.cert.Certificate;
import java.security.cert.CertificateEncodingException;
import java.security.cert.CertificateExpiredException;
import java.security.cert.CertificateFactory;
import java.security.cert.CertificateNotYetValidException;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import javax.management.MBeanInfo;
import javax.management.MBeanServer;
import javax.management.ObjectName;
import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.apache.axis2.context.MessageContext;
import org.apache.axis2.transport.http.HTTPConstants;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import com.vmware.vim.vasa._1_0.InvalidArgument;
import com.vmware.vim.vasa._1_0.InvalidCertificate;
import com.vmware.vim.vasa._1_0.InvalidSession;
import com.vmware.vim.vasa._1_0.StorageFault;
//import de.hunsicker.jalopy.storage.History.Method;
/**
* Helper functions for handling SSL certificates.
*/
public class SSLUtil {
private String trustStoreFileName;
private String trustStorePassword;
private boolean mustUseSSL;
private Log log = LogFactory.getLog(SSLUtil.class);
public static final int HASH_LENGTH = 20;
public static String VASA_SESSIONID_STR = "VASASESSIONID";
/**
* Constructor
*/
public SSLUtil(String fileName, String password, boolean SSLOnly) {
trustStoreFileName = fileName;
trustStorePassword = password;
mustUseSSL = SSLOnly;
}
/**
* return the value of the given HTTP cookie
*
* @param cookieName
*/
public String getCookie(String cookieName) throws InvalidSession {
MessageContext currentMessageContext = MessageContext
.getCurrentMessageContext();
if (currentMessageContext == null) {
throw FaultUtil.InvalidSession("No current message context");
}
HttpServletRequest req = (HttpServletRequest) currentMessageContext
.getProperty(HTTPConstants.MC_HTTP_SERVLETREQUEST);
if (req == null) {
throw FaultUtil.InvalidSession("No HTTP Servlet Request");
}
Cookie[] cookies = req.getCookies();
if (cookies == null) {
return null;
}
for (int i = 0; i < cookies.length; i++) {
if (cookies[i].getName().equals(cookieName)) {
return cookies[i].getValue();
}
}
return null;
}
/**
* set the given HTTP cookie
*
* @param cookieName
* @param cookieValue
*/
public void setCookie(String cookieName, String cookieValue)
throws InvalidSession {
MessageContext currentMessageContext = MessageContext
.getCurrentMessageContext();
if (currentMessageContext == null) {
throw FaultUtil.InvalidSession("No current message context");
}
HttpServletResponse resp = (HttpServletResponse) currentMessageContext
.getProperty(HTTPConstants.MC_HTTP_SERVLETRESPONSE);
if (resp == null) {
throw FaultUtil.InvalidSession("No HTTP Servlet Response");
}
Cookie cookie = new Cookie(cookieName, cookieValue);
resp.addCookie(cookie);
}
/**
* setHttpResponse
*
* @param sc
*/
public void setHttpResponse(SessionContext sc) throws InvalidSession {
if (sc != null) {
setCookie(VASA_SESSIONID_STR, sc.getSessionId());
}
}
private void checkHttpForValidVASASession() throws InvalidSession {
/*
* Check for a valid VASA Session.
*/
final String methodName = "checkHttpForValidVASASession(): ";
String sessionId = getCookie(VASA_SESSIONID_STR);
log.debug(methodName + "Current session ID[" + sessionId + "]");
if (sessionId == null) {
throw FaultUtil
.InvalidSession("No valid VASA SessionId in HTTP header");
}
try {
SessionContext sc = SessionContext
.lookupSessionContextBySessionId(sessionId);
if (sc == null) {
throw FaultUtil.InvalidSession("Invalid VASA SessionId "
+ sessionId + " in HTTP header");
}
} catch (Exception e) {
throw FaultUtil.InvalidSession("Could not find session context "
+ e);
}
}
public void checkForUniqueVASASessionId() throws InvalidSession {
/*
* Check for a valid VASA Session.
*/
final String methodName = "checkForUniqueVASASessionId(): ";
String sessionId = getCookie(VASA_SESSIONID_STR);
log.debug(methodName + " Current session ID: [" + sessionId + "]");
if (sessionId != null) {
boolean isPreviouslyUsedSessionId = SessionContext
.IsPreviouslyUsed(sessionId);
log.debug(methodName + " Is this session ID used previously? ["
+ isPreviouslyUsedSessionId + "]");
if (isPreviouslyUsedSessionId) {
throw FaultUtil
.InvalidSession("This session Id is not unique. It is previously used:["
+ sessionId + "]");
}
}
}
private void checkHttpForValidSSLSession(HttpServletRequest req)
throws InvalidSession, InvalidCertificate {
/*
* Check for a valid SSL Session.
*/
X509Certificate[] sslCerts = (X509Certificate[]) req
.getAttribute("javax.servlet.request.X509Certificate");
if ((sslCerts == null) || (sslCerts.length == 0)) {
throw FaultUtil
.InvalidSession("No SSL Client Certificate attached to HTTPS session");
}
if (!certificateIsTrusted(sslCerts[0])) {
throw FaultUtil
.InvalidSession("No Trusted SSL Client Certificate attached to HTTPS session");
}
/**
* Note that a certificate that is trusted by this server, but one that
* has not necessarily been registered via a call to
* registerVASACertficate() will be accepted as valid.
*/
}
/**
* checkHttpRequest
*
* The term "Session" is overloaded. A Session can refer to either a SSL
* session or it can refer to a VASA session.
*
* If there is an error in either of the Session configurations, then this
* routine will throw the InvalidSession expection.
*
* @param validClientCertificateNeeded
* @param validSessionIdNeeed
*/
public String checkHttpRequest(boolean validSSLSessionNeeded,
boolean validVASASessionNeeded) throws InvalidSession {
final String methodName = "checkHttpRequest(): ";
try {
/*
* Check for a valid context.
*/
log.trace(methodName + "Entry with inputs validSSLSessionNeeded["
+ validSSLSessionNeeded + "] validVASASessionNeeded["
+ validVASASessionNeeded + "]");
MessageContext currentMessageContext = MessageContext
.getCurrentMessageContext();
if (currentMessageContext == null) {
throw FaultUtil.InvalidSession("No current message context");
}
String clientAddress = (String) currentMessageContext
.getProperty("REMOTE_ADDR");
// log.debug("Request from client at ip addr: " + clientAddress);
HttpServletRequest req = (HttpServletRequest) currentMessageContext
.getProperty(HTTPConstants.MC_HTTP_SERVLETREQUEST);
if (req == null) {
throw FaultUtil.InvalidSession("No HTTP Servlet Request");
}
/**
* Get SSL data
*/
String sslSessionId = (String) req
.getAttribute("javax.servlet.request.ssl_session");
if (sslSessionId == null) {
/**
* This is not an SSL connection. If the service is not allowing
* none-SSL connections, throw an exception. Otherwise check for
* a valid VASA session if necessary.
*/
if (!mustUseSSL) {
if (validVASASessionNeeded) {
checkHttpForValidVASASession();
}
log.trace(methodName + "Exit returning clientAddress["
+ clientAddress + "]");
return clientAddress;
} else {
throw FaultUtil.InvalidSession("Must use SSL connection");
}
}
/*
* At this point, it is known that there is a well formed HTTPS
* session.
*/
if (validSSLSessionNeeded) {
checkHttpForValidSSLSession(req);
}
if (validVASASessionNeeded) {
checkHttpForValidVASASession();
}
log.trace(methodName + "Exit returning clientAddress["
+ clientAddress + "]");
return clientAddress;
} catch (InvalidCertificate ic) {
// InvalidCertificate can be thrown by certificateIsTrusted
log.error(methodName + "invalid certificate exception ", ic);
throw FaultUtil.InvalidSession("Non trusted certificate.");
} catch (InvalidSession is) {
log.error(methodName + "invalid session exception ", is);
throw is;
} catch (Exception e) {
log.error(methodName + "Exception occured ", e);
throw FaultUtil
.InvalidSession(
"checkHttpSession unexpected exception. Convert to InvalidSession.",
e);
}
}
/**
* getCertificateThumbprint
*
* @param cert
*/
public String getCertificateThumbprint(Certificate cert)
throws InvalidArgument {
// Compute the SHA-1 hash of the certificate.
try {
byte[] encoded;
try {
encoded = cert.getEncoded();
} catch (CertificateEncodingException cee) {
throw FaultUtil.InvalidArgument(
"Error reading certificate encoding: "
+ cee.getMessage(), cee);
}
MessageDigest sha1;
try {
sha1 = MessageDigest.getInstance("SHA-1");
} catch (NoSuchAlgorithmException e) {
throw FaultUtil.InvalidArgument(
"Could not instantiate SHA-1 hash algorithm", e);
}
sha1.update(encoded);
byte[] hash = sha1.digest();
if (hash.length != HASH_LENGTH) {
throw FaultUtil.InvalidArgument("Computed thumbprint is "
+ hash.length + " bytes long, expected " + HASH_LENGTH);
}
StringBuilder thumbprintString = new StringBuilder(hash.length * 3);
for (int i = 0; i < hash.length; i++) {
if (i > 0) {
thumbprintString.append(":");
}
String hexByte = Integer.toHexString(0xFF & (int) hash[i]);
if (hexByte.length() == 1) {
thumbprintString.append("0");
}
thumbprintString.append(hexByte);
}
return thumbprintString.toString().toUpperCase();
} catch (InvalidArgument ia) {
throw ia;
} catch (Exception e) {
throw FaultUtil.InvalidArgument("Exception: " + e);
}
}
/**
* buildCertificate Build a certificate from a Base64 formatted, PKCS#7
* encoding of the certificate
*
* @param certString
*/
public Certificate buildCertificate(String certString)
throws InvalidCertificate {
try {
String base64Cert = formatCertificate(certString);
InputStream inBytes = new ByteArrayInputStream(
base64Cert.getBytes());
CertificateFactory cf = CertificateFactory.getInstance("X.509");
assert inBytes.available() > 0;
Certificate certificate = cf.generateCertificate(inBytes);
inBytes.close();
return certificate;
} catch (Exception e) {
log.debug("buildCertificate: error " + e
+ " converted to InvalidCertificate.");
throw FaultUtil.InvalidCertificate("Could not build certificate");
}
}
private String formatCertificate(String cert) {
final String HEADER = "-----BEGIN CERTIFICATE-----";
final String FOOTER = "-----END CERTIFICATE-----";
if (cert.trim().startsWith(HEADER)) {
return cert;
}
StringBuffer sb = new StringBuffer();
sb.append(HEADER);
sb.append("\n");
sb.append(cert.trim());
sb.append("\n");
sb.append(FOOTER);
return sb.toString();
}
/**
*
* Format of the alias is: "vpc-<integer>" For example, "vpc-3"
*/
private String getAlias(String clientAddress) throws InvalidCertificate {
int count = 0;
String certAliasBase = new String("vpc-");
String certAlias = certAliasBase.concat(Integer.toString(count));
while (getCertificateFromAlias(certAlias) != null) {
/**
* Need to make sure that certAlias is not already in the
* trustStore. If it is, create a different alias so as not to
* overwrite an existing certificate.
*/
count++;
certAlias = certAliasBase.concat(Integer.toString(count));
}
log.debug("getCertificateFromAlias() " + certAlias
+ " for certificate from " + clientAddress);
return certAlias;
}
/**
* addCertifcateToTrustStore
*
* @param certNameRoot
* ,
* @param certToAdd
*/
public void addCertificateToTrustStore(String certNameRoot,
Certificate certToAdd) throws InvalidArgument {
try {
KeyStore ts = KeyStore.getInstance("JKS");
FileInputStream is = new FileInputStream(trustStoreFileName);
ts.load(is, trustStorePassword.toCharArray());
is.close();
String certAlias = getAlias(certNameRoot);
ts.setCertificateEntry(certAlias, certToAdd);
FileOutputStream out = new FileOutputStream(trustStoreFileName);
ts.store(out, trustStorePassword.toCharArray());
out.close();
log.debug("Certificate with alias " + certAlias
+ " added to truststore");
} catch (Exception e) {
throw FaultUtil.InvalidArgument("Exception " + e);
}
}
/**
* removeCertifcateFromTrustStore
*
* @param certToAdd
*/
public void removeCertificateFromTrustStore(Certificate certToRemove)
throws InvalidArgument {
try {
KeyStore ts = KeyStore.getInstance("JKS");
FileInputStream is = new FileInputStream(trustStoreFileName);
ts.load(is, trustStorePassword.toCharArray());
is.close();
Enumeration<String> aliases = ts.aliases();
while (aliases.hasMoreElements()) {
String alias = aliases.nextElement();
if (ts.isCertificateEntry(alias)) {
X509Certificate tc = (X509Certificate) ts
.getCertificate(alias);
if (tc.equals(certToRemove)) {
ts.deleteEntry(alias);
}
}
}
FileOutputStream out = new FileOutputStream(trustStoreFileName);
ts.store(out, trustStorePassword.toCharArray());
out.close();
} catch (Exception e) {
throw FaultUtil.InvalidArgument("Exception " + e);
}
}
/**
* certificateIsTrusted
*
* @param certToCheck
*/
public boolean certificateIsTrusted(Certificate certToCheck)
throws InvalidCertificate {
try {
KeyStore ts = KeyStore.getInstance("JKS");
FileInputStream is = new FileInputStream(trustStoreFileName);
ts.load(is, trustStorePassword.toCharArray());
is.close();
Enumeration<String> aliases = ts.aliases();
while (aliases.hasMoreElements()) {
String alias = aliases.nextElement();
if (ts.isCertificateEntry(alias)) {
/**
* certificate is trusted
*/
X509Certificate tc = (X509Certificate) ts
.getCertificate(alias);
try {
tc.checkValidity();
/**
* certificate is valid
*/
if (tc.equals(certToCheck)) {
return true;
} else {
log.warn("Certificate [" + alias
+ "] is not valid.");
}
} catch (CertificateNotYetValidException e) {
log.error("Certificate is not yet valid: ", e);
throw e;
} catch (CertificateExpiredException e) {
log.error("Certificate is expired: ", e);
throw e;
} /*
* catch (InvalidCertificate e) {
* throw e;
* }
*/
}
}
return false;
} catch (Exception e) {
throw FaultUtil.InvalidCertificate("Exception: " + e);
}
}
/**
* getCertificateAlias
*
* @param cert
*/
public String getCertificateAlias(Certificate cert)
throws InvalidCertificate {
try {
KeyStore ts = KeyStore.getInstance("JKS");
FileInputStream is = new FileInputStream(trustStoreFileName);
ts.load(is, trustStorePassword.toCharArray());
is.close();
return ts.getCertificateAlias(cert);
} catch (Exception e) {
throw FaultUtil.InvalidCertificate("Exception: " + e);
}
}
/**
* getCertificateFromAlias return the certificate corresponding to this
* alias
*
* @param certString
*/
public Certificate getCertificateFromAlias(String certAlias)
throws InvalidCertificate {
try {
KeyStore ts = KeyStore.getInstance("JKS");
FileInputStream is = new FileInputStream(trustStoreFileName);
ts.load(is, trustStorePassword.toCharArray());
is.close();
return ts.getCertificate(certAlias);
} catch (Exception e) {
throw FaultUtil.InvalidCertificate("Exception: " + e);
}
}
/**
* thumbprintIsTrusted
*
* @param thumbprint
*/
public void thumbprintIsTrusted(String thumbprint)
throws InvalidCertificate {
try {
KeyStore ts = KeyStore.getInstance("JKS");
FileInputStream is = new FileInputStream(trustStoreFileName);
ts.load(is, trustStorePassword.toCharArray());
is.close();
Enumeration<String> aliases = ts.aliases();
while (aliases.hasMoreElements()) {
String alias = aliases.nextElement();
if (ts.isCertificateEntry(alias)) {
/**
* certificate is trusted
*/
X509Certificate tc = (X509Certificate) ts
.getCertificate(alias);
if (thumbprint.equals(getCertificateThumbprint(ts
.getCertificate(alias)))) {
try {
tc.checkValidity();
return;
} catch (Exception e) {
throw FaultUtil.InvalidCertificate(
"cert with thumprint is not valid", e);
}
}
}
}
throw FaultUtil
.InvalidCertificate("could not find certifcate that matches thumbprint");
} catch (InvalidCertificate ic) {
throw ic;
} catch (Exception e) {
throw FaultUtil.InvalidCertificate("Exception: " + e);
}
}
/**
* Stop and restart the SSL connection so that the tomcat server will
* re-read the certificates from the truststore file.
*
*/
public void refreshTrustStore() throws Exception {
try {
// MBeanServer mBeanServer = MBeanUtils.createServer();
MBeanServer mBeanServer = ManagementFactory
.getPlatformMBeanServer();
Set names = mBeanServer.queryNames(new ObjectName("*:*"), null);
Iterator it = names.iterator();
while (it.hasNext()) {
ObjectName oname = (ObjectName) it.next();
MBeanInfo minfo = mBeanServer.getMBeanInfo(oname);
String mBeanInfoClass = minfo.getClassName();
boolean condition = "org.apache.catalina.mbeans.ConnectorMBean"
.equals(mBeanInfoClass)
|| "org.mortbay.jetty.security.SslSocketConnector"
.equals(mBeanInfoClass)
|| "org.eclipse.jetty.server.ssl.SslSocketConnector"
.equals(mBeanInfoClass);
if (condition) {
String protocol = (String) mBeanServer.getAttribute(oname,
"protocol");
if (protocol.toLowerCase().startsWith("http")) {
boolean isSecure = ((mBeanServer.getAttribute(oname,
"secure") != null) && (mBeanServer
.getAttribute(oname, "secure").toString()
.equalsIgnoreCase("true")));
boolean isSchemeHTTPS = ((mBeanServer.getAttribute(
oname, "scheme") != null) && (mBeanServer
.getAttribute(oname, "scheme").toString()
.equalsIgnoreCase("https")));
if (isSecure && isSchemeHTTPS) {
log.debug("Restarting SSL Connector on port "
+ (Object) mBeanServer.getAttribute(oname,
"port"));
Object params[] = {};
String signature[] = {};
/**
* Stop and restart the connector to get it to
* re-read the certificate trustfile
*/
mBeanServer
.invoke(oname, "stop", params, signature);
mBeanServer.invoke(oname, "start", params,
signature);
}
}
}
}
} catch (Exception e) {
log.debug("Did not restart SSL Connector: " + e);
throw e;
}
}
public SessionContext getCurrentSessionContext() throws InvalidSession,
StorageFault {
String sessionId = getCookie(VASA_SESSIONID_STR);
return SessionContext.lookupSessionContextBySessionId(sessionId);
}
}