package com.hwlcn.ldap.util.ssl;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.IOException;
import java.io.PrintStream;
import java.security.MessageDigest;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.Date;
import java.util.concurrent.ConcurrentHashMap;
import javax.net.ssl.X509TrustManager;
import javax.security.auth.x500.X500Principal;
import com.hwlcn.core.annotation.NotMutable;
import com.hwlcn.core.annotation.ThreadSafety;
import com.hwlcn.ldap.util.ThreadSafetyLevel;
import static com.hwlcn.ldap.util.Debug.*;
import static com.hwlcn.ldap.util.StaticUtils.*;
import static com.hwlcn.ldap.util.ssl.SSLMessages.*;
/**
 * An {@link X509TrustManager} that interactively asks the user whether a
 * presented certificate chain should be trusted.
 * <p>
 * For each chain that is not already trusted, the subject, MD5 and SHA-1
 * fingerprints of the peer certificate and of every other certificate in the
 * chain are printed, along with the peer certificate's validity window and any
 * self-signed / validity warnings.  The user is then asked to answer "y"/"yes"
 * or "n"/"no".  Accepted certificates (keyed by the lowercase hex encoding of
 * their signature) are remembered for the lifetime of this object and, if a
 * cache file was configured, written to that file one key per line so they
 * are trusted without prompting when a new instance loads the file.
 */
@NotMutable()
@ThreadSafety(level=ThreadSafetyLevel.COMPLETELY_THREADSAFE)
public final class PromptTrustManager
  implements X509TrustManager
{
  /**
   * Shared MD5 digest used for fingerprints.  {@code MessageDigest} instances
   * are stateful, so every use is synchronized on the instance.
   */
  private static final MessageDigest MD5;

  /**
   * Shared SHA-1 digest used for fingerprints.  Every use is synchronized on
   * the instance.
   */
  private static final MessageDigest SHA1;

  static
  {
    MessageDigest d = null;
    try
    {
      d = MessageDigest.getInstance("MD5");
    }
    catch (final Exception e)
    {
      debugException(e);
      throw new RuntimeException(e);
    }
    MD5 = d;

    d = null;
    try
    {
      d = MessageDigest.getInstance("SHA-1");
    }
    catch (final Exception e)
    {
      debugException(e);
      throw new RuntimeException(e);
    }
    SHA1 = d;
  }

  /** Whether the peer certificate's validity window is checked. */
  private final boolean examineValidityDates;

  /**
   * Certificates accepted so far, keyed by lowercase hex signature.  The value
   * is {@code true} if the certificate was accepted while it was outside its
   * validity window.  Entries loaded from the cache file are always
   * {@code false}.
   */
  private final ConcurrentHashMap<String,Boolean> acceptedCerts;

  /**
   * Reader over the configured input stream.  A single reader is shared across
   * prompts so that characters it buffers beyond the end of one answer are not
   * lost for the next prompt.  Access is serialized because
   * {@code checkCertificateChain} is synchronized.
   */
  private final BufferedReader reader;

  /** Stream to which certificate details and prompts are written. */
  private final PrintStream out;

  /** Path of the accepted-certificate cache file, or {@code null} if none. */
  private final String acceptedCertsFile;

  /**
   * Creates a trust manager that reads from {@code System.in}, writes to
   * {@code System.out}, examines validity dates, and does not persist
   * accepted certificates.
   */
  public PromptTrustManager()
  {
    this(null, true, null, null);
  }

  /**
   * Creates a trust manager that reads from {@code System.in}, writes to
   * {@code System.out}, and examines validity dates.
   *
   * @param  acceptedCertsFile  The path of the file used to persist accepted
   *                            certificates, or {@code null} to not persist
   *                            them.
   */
  public PromptTrustManager(final String acceptedCertsFile)
  {
    this(acceptedCertsFile, true, null, null);
  }

  /**
   * Creates a trust manager with the given settings.  If the cache file
   * exists, each of its lines is loaded as an accepted certificate key.  A
   * failure while reading the file is only passed to {@code debugException};
   * construction still succeeds.
   *
   * @param  acceptedCertsFile     The path of the file used to persist
   *                               accepted certificates, or {@code null} to
   *                               not persist them.
   * @param  examineValidityDates  Whether to warn about (and re-prompt for)
   *                               certificates outside their validity window.
   * @param  in                    The stream from which answers are read, or
   *                               {@code null} to use {@code System.in}.
   * @param  out                   The stream to which prompts are written, or
   *                               {@code null} to use {@code System.out}.
   */
  public PromptTrustManager(final String acceptedCertsFile,
                            final boolean examineValidityDates,
                            final InputStream in, final PrintStream out)
  {
    this.acceptedCertsFile = acceptedCertsFile;
    this.examineValidityDates = examineValidityDates;

    final InputStream inputStream = (in == null) ? System.in : in;
    reader = new BufferedReader(new InputStreamReader(inputStream));
    this.out = (out == null) ? System.out : out;

    acceptedCerts = new ConcurrentHashMap<String,Boolean>();
    if (acceptedCertsFile != null)
    {
      loadCacheFile(new File(acceptedCertsFile));
    }
  }

  /**
   * Loads accepted certificate keys from the given file, if it exists.  Errors
   * are deliberately best-effort: they are passed to {@code debugException}
   * and otherwise ignored.
   *
   * @param  f  The cache file to read.
   */
  private void loadCacheFile(final File f)
  {
    BufferedReader r = null;
    try
    {
      if (f.exists())
      {
        r = new BufferedReader(new FileReader(f));
        while (true)
        {
          final String line = r.readLine();
          if (line == null)
          {
            break;
          }
          acceptedCerts.put(line, false);
        }
      }
    }
    catch (final Exception e)
    {
      debugException(e);
    }
    finally
    {
      if (r != null)
      {
        try
        {
          r.close();
        }
        catch (final Exception e)
        {
          debugException(e);
        }
      }
    }
  }

  /**
   * Writes all accepted certificate keys to the cache file.  The keys are
   * first written to "{@code <file>.new}".  Any existing cache file is moved to
   * "{@code <file>.previous}" (replacing an older backup), and the new file is
   * then renamed into place.
   *
   * @throws  IOException  If the temporary file cannot be written, or if it
   *                       cannot be renamed to the cache file.
   */
  private void writeCacheFile()
          throws IOException
  {
    final File tempFile = new File(acceptedCertsFile + ".new");

    BufferedWriter w = null;
    try
    {
      w = new BufferedWriter(new FileWriter(tempFile));
      for (final String certBytes : acceptedCerts.keySet())
      {
        w.write(certBytes);
        w.newLine();
      }
    }
    finally
    {
      if (w != null)
      {
        w.close();
      }
    }

    final File cacheFile = new File(acceptedCertsFile);
    if (cacheFile.exists())
    {
      final File oldFile = new File(acceptedCertsFile + ".previous");
      if (oldFile.exists())
      {
        oldFile.delete();
      }
      cacheFile.renameTo(oldFile);
    }

    // File.renameTo reports failure only through its return value; surface it
    // so a failed cache update is not silently lost.
    if (! tempFile.renameTo(cacheFile))
    {
      throw new IOException("Unable to rename '" +
           tempFile.getAbsolutePath() + "' to '" +
           cacheFile.getAbsolutePath() + "'.");
    }
  }

  /**
   * Decides whether the given chain is trusted, prompting the user if it is
   * not already accepted.  The method is synchronized so that only one prompt
   * is active at a time.
   *
   * @param  chain       The certificate chain, peer certificate first.
   * @param  serverCert  {@code true} for a server chain, {@code false} for a
   *                     client chain (selects the heading that is printed).
   *
   * @throws  CertificateException  If the user rejects the chain, no answer
   *                                can be read, or a certificate cannot be
   *                                encoded for fingerprinting.
   */
  private synchronized void checkCertificateChain(final X509Certificate[] chain,
                                                  final boolean serverCert)
          throws CertificateException
  {
    if ((chain == null) || (chain.length == 0))
    {
      // Matches the X509TrustManager contract for null / empty chains.
      throw new IllegalArgumentException(
           "The certificate chain must not be null or empty.");
    }

    final X509Certificate c = chain[0];
    final String validityWarning = getValidityWarning(c);
    final String certBytes = toLowerCase(toHex(c.getSignature()));

    // Skip the prompt if the certificate was already accepted, unless it is
    // now outside its validity window and was only accepted while inside it.
    final Boolean accepted = acceptedCerts.get(certBytes);
    if ((accepted != null) &&
        ((validityWarning == null) || Boolean.TRUE.equals(accepted)))
    {
      return;
    }

    printCertificateChain(chain, serverCert, validityWarning);
    promptForAcceptance();

    // Remember whether the user accepted it despite a validity warning.
    acceptedCerts.put(certBytes, (validityWarning != null));
    if (acceptedCertsFile != null)
    {
      try
      {
        writeCacheFile();
      }
      catch (final Exception e)
      {
        // Persisting is best-effort; the certificate is still accepted for
        // this session.
        debugException(e);
      }
    }
  }

  /**
   * Returns the warning message to display for the certificate's validity
   * window, if any.
   *
   * @param  c  The peer certificate.
   *
   * @return  The not-yet-valid or expired warning, or {@code null} if the
   *          certificate is within its validity window or validity dates are
   *          not being examined.
   */
  private String getValidityWarning(final X509Certificate c)
  {
    if (! examineValidityDates)
    {
      return null;
    }

    final Date currentDate = new Date();
    if (currentDate.before(c.getNotBefore()))
    {
      return WARN_PROMPT_NOT_YET_VALID.get();
    }
    else if (currentDate.after(c.getNotAfter()))
    {
      return WARN_PROMPT_EXPIRED.get();
    }
    return null;
  }

  /**
   * Prints the heading, the subject and fingerprints of every certificate in
   * the chain, the peer certificate's validity window, and any warnings.
   *
   * @param  chain            The certificate chain, peer certificate first.
   * @param  serverCert       Whether the chain belongs to a server.
   * @param  validityWarning  The validity warning to print, or {@code null}.
   *
   * @throws  CertificateException  If a certificate cannot be encoded.
   */
  private void printCertificateChain(final X509Certificate[] chain,
                                     final boolean serverCert,
                                     final String validityWarning)
          throws CertificateException
  {
    final X509Certificate c = chain[0];

    if (serverCert)
    {
      out.println(INFO_PROMPT_SERVER_HEADING.get());
    }
    else
    {
      out.println(INFO_PROMPT_CLIENT_HEADING.get());
    }

    out.println('\t' + INFO_PROMPT_SUBJECT.get(
         c.getSubjectX500Principal().getName(X500Principal.CANONICAL)));
    out.println("\t\t" + INFO_PROMPT_MD5_FINGERPRINT.get(
         getFingerprint(c, MD5)));
    out.println("\t\t" + INFO_PROMPT_SHA1_FINGERPRINT.get(
         getFingerprint(c, SHA1)));

    for (int i=1; i < chain.length; i++)
    {
      out.println('\t' + INFO_PROMPT_ISSUER_SUBJECT.get(i,
           chain[i].getSubjectX500Principal().getName(
                X500Principal.CANONICAL)));
      out.println("\t\t" + INFO_PROMPT_MD5_FINGERPRINT.get(
           getFingerprint(chain[i], MD5)));
      out.println("\t\t" + INFO_PROMPT_SHA1_FINGERPRINT.get(
           getFingerprint(chain[i], SHA1)));
    }

    out.println(INFO_PROMPT_VALIDITY.get(String.valueOf(c.getNotBefore()),
         String.valueOf(c.getNotAfter())));

    // A single-certificate chain carries no issuer certificate.
    if (chain.length == 1)
    {
      out.println();
      out.println(WARN_PROMPT_SELF_SIGNED.get());
    }

    if (validityWarning != null)
    {
      out.println();
      out.println(validityWarning);
    }
  }

  /**
   * Repeatedly prompts until the user answers yes or no (case-insensitive;
   * surrounding whitespace is ignored).  Any other answer re-prompts.
   *
   * @throws  CertificateException  If the user answers no, the input stream
   *                                is exhausted, or reading from it fails.
   *                                Both end of input and read failures are
   *                                treated as a rejection so the check fails
   *                                closed instead of looping forever.
   */
  private void promptForAcceptance()
          throws CertificateException
  {
    while (true)
    {
      out.println();
      out.println(INFO_PROMPT_MESSAGE.get());
      out.flush();

      final String line;
      try
      {
        line = reader.readLine();
      }
      catch (final IOException ioe)
      {
        debugException(ioe);
        throw new CertificateException(
             ERR_CERTIFICATE_REJECTED_BY_USER.get(), ioe);
      }

      if (line == null)
      {
        // End of input: no answer will ever arrive.
        throw new CertificateException(ERR_CERTIFICATE_REJECTED_BY_USER.get());
      }

      final String answer = line.trim();
      if (answer.equalsIgnoreCase("y") || answer.equalsIgnoreCase("yes"))
      {
        return;
      }
      else if (answer.equalsIgnoreCase("n") || answer.equalsIgnoreCase("no"))
      {
        throw new CertificateException(ERR_CERTIFICATE_REJECTED_BY_USER.get());
      }
    }
  }

  /**
   * Computes a colon-separated hex fingerprint of the certificate's encoded
   * form.
   *
   * @param  c  The certificate to fingerprint.
   * @param  d  The digest to use.  Access is synchronized on it because
   *            {@code MessageDigest} is stateful.
   *
   * @return  The hex fingerprint with {@code ':'} between bytes.
   *
   * @throws  CertificateException  If the certificate cannot be encoded.
   */
  private static String getFingerprint(final X509Certificate c,
                                       final MessageDigest d)
          throws CertificateException
  {
    final byte[] encodedCertBytes = c.getEncoded();

    final byte[] digestBytes;
    synchronized (d)
    {
      digestBytes = d.digest(encodedCertBytes);
    }

    // Two hex digits plus one separator per digest byte.
    final StringBuilder buffer = new StringBuilder(3 * digestBytes.length);
    toHex(digestBytes, ":", buffer);
    return buffer.toString();
  }

  /**
   * Indicates whether certificate validity dates are examined.
   *
   * @return  {@code true} if validity dates are examined.
   */
  public boolean examineValidityDates()
  {
    return examineValidityDates;
  }

  /**
   * Checks whether a client certificate chain should be trusted, prompting
   * the user if needed.
   *
   * @param  chain     The client certificate chain, peer certificate first.
   * @param  authType  The authentication type (not used).
   *
   * @throws  CertificateException  If the chain is not trusted.
   */
  public void checkClientTrusted(final X509Certificate[] chain,
                                 final String authType)
         throws CertificateException
  {
    checkCertificateChain(chain, false);
  }

  /**
   * Checks whether a server certificate chain should be trusted, prompting
   * the user if needed.
   *
   * @param  chain     The server certificate chain, peer certificate first.
   * @param  authType  The key exchange algorithm (not used).
   *
   * @throws  CertificateException  If the chain is not trusted.
   */
  public void checkServerTrusted(final X509Certificate[] chain,
                                 final String authType)
         throws CertificateException
  {
    checkCertificateChain(chain, true);
  }

  /**
   * Returns the accepted issuer certificates.  This implementation
   * advertises none.
   *
   * @return  An empty array.
   */
  public X509Certificate[] getAcceptedIssuers()
  {
    return new X509Certificate[0];
  }
}