package org.bouncycastle.crypto.tls;

import java.util.Enumeration;
import java.util.Hashtable;

import org.bouncycastle.crypto.Digest;
import org.bouncycastle.util.Shorts;

/**
 * Buffers input until the hash algorithm is determined.
 * <p>
 * While buffering, all bytes passed to {@code update} are written to an internal
 * {@link DigestInputBuffer}. Hash algorithms of interest are registered via
 * {@link #trackHashAlgorithm(short)} and {@link #notifyPRFDetermined()}. Once buffering stops
 * (see {@link #checkStopBuffering()}), the buffered bytes are replayed into every tracked digest,
 * and later updates are fed to those digests directly.
 * <p>
 * NOTE(review): no synchronization is visible here; presumably an instance is confined to a
 * single handshake thread — confirm.
 */
class DeferredHash
    implements TlsHandshakeHash
{
    /**
     * {@link #checkStopBuffering()} only stops buffering while at most this many hash algorithms
     * are tracked; with more, input keeps being buffered.
     */
    protected static final int BUFFERING_HASH_LIMIT = 4;

    protected TlsContext context;

    /** Buffered input; {@code null} once input is fed straight to the tracked digests. */
    private DigestInputBuffer buf;

    /** Tracked digests, keyed by the boxed hash algorithm value. */
    private final Hashtable hashes;

    /** Hash algorithm used by the PRF; {@code null} until {@link #notifyPRFDetermined()} sets it. */
    private Short prfHashAlgorithm;

    /**
     * Creates a hash that buffers all input until buffering is stopped.
     */
    DeferredHash()
    {
        this.buf = new DigestInputBuffer();
        this.hashes = new Hashtable();
        this.prfHashAlgorithm = null;
    }

    /**
     * Creates a non-buffering hash that tracks only the given PRF digest.
     *
     * @param prfHashAlgorithm the PRF hash algorithm
     * @param prfHash          a digest that has already absorbed the input seen so far
     */
    private DeferredHash(Short prfHashAlgorithm, Digest prfHash)
    {
        this.buf = null;
        this.hashes = new Hashtable();
        this.prfHashAlgorithm = prfHashAlgorithm;
        hashes.put(prfHashAlgorithm, prfHash);
    }

    public void init(TlsContext context)
    {
        this.context = context;
    }

    /**
     * Reacts to the PRF algorithm having been fixed in the security parameters.
     * <p>
     * For {@link PRFAlgorithm#tls_prf_legacy} the buffered input is replayed into a new
     * {@link CombinedHash}, and that hash's own {@code notifyPRFDetermined()} result is returned.
     * Otherwise the PRF's hash algorithm is recorded and tracked, and this instance is returned.
     *
     * @return the handshake hash to use from now on
     * @throws IllegalStateException if buffering has already stopped and the PRF hash would
     *         therefore miss earlier input
     */
    public TlsHandshakeHash notifyPRFDetermined()
    {
        int prfAlgorithm = context.getSecurityParameters().getPrfAlgorithm();
        if (prfAlgorithm == PRFAlgorithm.tls_prf_legacy)
        {
            // The replacement hash must see every byte so far, which is only possible while the
            // input is still buffered.
            if (buf == null)
            {
                throw new IllegalStateException("Too late to switch to the legacy PRF hash");
            }

            CombinedHash legacyHash = new CombinedHash();
            legacyHash.init(context);
            buf.updateDigest(legacyHash);
            return legacyHash.notifyPRFDetermined();
        }

        Short hashAlgorithm = Shorts.valueOf(TlsUtils.getHashAlgorithmForPRFAlgorithm(prfAlgorithm));

        // A digest created after buffering stopped would have missed the earlier input, so at
        // that point only an already-tracked algorithm is acceptable (same rule as
        // trackHashAlgorithm).
        if (buf == null && !hashes.containsKey(hashAlgorithm))
        {
            throw new IllegalStateException("Too late to track more hash algorithms");
        }

        this.prfHashAlgorithm = hashAlgorithm;

        checkTrackingHash(hashAlgorithm);

        return this;
    }

    /**
     * Registers an additional hash algorithm whose digest of the input will be needed.
     *
     * @param hashAlgorithm the hash algorithm to track
     * @throws IllegalStateException if buffering has already stopped
     */
    public void trackHashAlgorithm(short hashAlgorithm)
    {
        if (buf == null)
        {
            throw new IllegalStateException("Too late to track more hash algorithms");
        }

        checkTrackingHash(Shorts.valueOf(hashAlgorithm));
    }

    /**
     * Signals that the set of tracked algorithms is final; stops buffering if the number of
     * tracked hashes allows it.
     */
    public void sealHashAlgorithms()
    {
        checkStopBuffering();
    }

    /**
     * Creates a new, non-buffering hash that tracks only the PRF digest, seeded with all input
     * seen so far.
     *
     * @return the new handshake hash
     * @throws IllegalStateException if the PRF hash algorithm has not been determined
     */
    public TlsHandshakeHash stopTracking()
    {
        checkPRFHashDetermined();

        Digest prfHash = TlsUtils.cloneHash(prfHashAlgorithm.shortValue(), (Digest)hashes.get(prfHashAlgorithm));
        if (buf != null)
        {
            // Tracked digests have not seen the buffered input yet.
            buf.updateDigest(prfHash);
        }
        DeferredHash result = new DeferredHash(prfHashAlgorithm, prfHash);
        result.init(context);
        return result;
    }

    /**
     * Returns an independent PRF digest holding all input seen so far. Stops buffering first if
     * possible.
     *
     * @return a fresh digest that the caller may finalize
     * @throws IllegalStateException if the PRF hash algorithm has not been determined
     */
    public Digest forkPRFHash()
    {
        // Validate before checkStopBuffering() so a failed call leaves the state untouched.
        checkPRFHashDetermined();

        checkStopBuffering();

        if (buf != null)
        {
            // Still buffering (too many tracked hashes): build the PRF digest from the buffer.
            Digest prfHash = TlsUtils.createHash(prfHashAlgorithm.shortValue());
            buf.updateDigest(prfHash);
            return prfHash;
        }

        return TlsUtils.cloneHash(prfHashAlgorithm.shortValue(), (Digest)hashes.get(prfHashAlgorithm));
    }

    /**
     * Computes the digest of all input seen so far with the given tracked algorithm, without
     * finalizing the tracked digest itself.
     *
     * @param hashAlgorithm a tracked hash algorithm
     * @return the digest output
     * @throws IllegalStateException if the algorithm is not being tracked
     */
    public byte[] getFinalHash(short hashAlgorithm)
    {
        Digest d = (Digest)hashes.get(Shorts.valueOf(hashAlgorithm));
        if (d == null)
        {
            throw new IllegalStateException("HashAlgorithm." + HashAlgorithm.getText(hashAlgorithm) + " is not being tracked");
        }

        // Work on a clone so the tracked digest can keep absorbing input.
        d = TlsUtils.cloneHash(hashAlgorithm, d);
        if (buf != null)
        {
            buf.updateDigest(d);
        }

        byte[] bs = new byte[d.getDigestSize()];
        d.doFinal(bs, 0);
        return bs;
    }

    /**
     * Not supported: this object has no single algorithm.
     *
     * @throws IllegalStateException always
     */
    public String getAlgorithmName()
    {
        throw new IllegalStateException("Use forkPRFHash() to get a definite Digest");
    }

    /**
     * Not supported: this object has no single digest size.
     *
     * @throws IllegalStateException always
     */
    public int getDigestSize()
    {
        throw new IllegalStateException("Use forkPRFHash() to get a definite Digest");
    }

    /**
     * Adds one byte: to the buffer while buffering, otherwise to every tracked digest.
     */
    public void update(byte input)
    {
        if (buf != null)
        {
            buf.write(input);
            return;
        }

        Enumeration e = hashes.elements();
        while (e.hasMoreElements())
        {
            Digest hash = (Digest)e.nextElement();
            hash.update(input);
        }
    }

    /**
     * Adds {@code len} bytes of {@code input} starting at {@code inOff}: to the buffer while
     * buffering, otherwise to every tracked digest.
     */
    public void update(byte[] input, int inOff, int len)
    {
        if (buf != null)
        {
            buf.write(input, inOff, len);
            return;
        }

        Enumeration e = hashes.elements();
        while (e.hasMoreElements())
        {
            Digest hash = (Digest)e.nextElement();
            hash.update(input, inOff, len);
        }
    }

    /**
     * Not supported: use {@link #forkPRFHash()} or {@link #getFinalHash(short)}.
     *
     * @throws IllegalStateException always
     */
    public int doFinal(byte[] output, int outOff)
    {
        throw new IllegalStateException("Use forkPRFHash() to get a definite Digest");
    }

    /**
     * Discards buffered input, or resets every tracked digest once buffering has stopped.
     */
    public void reset()
    {
        if (buf != null)
        {
            buf.reset();
            return;
        }

        Enumeration e = hashes.elements();
        while (e.hasMoreElements())
        {
            Digest hash = (Digest)e.nextElement();
            hash.reset();
        }
    }

    /**
     * If still buffering and no more than {@link #BUFFERING_HASH_LIMIT} hashes are tracked,
     * replays the buffer into every tracked digest and switches to direct updates.
     */
    protected void checkStopBuffering()
    {
        if (buf != null && hashes.size() <= BUFFERING_HASH_LIMIT)
        {
            Enumeration e = hashes.elements();
            while (e.hasMoreElements())
            {
                Digest hash = (Digest)e.nextElement();
                buf.updateDigest(hash);
            }

            this.buf = null;
        }
    }

    /**
     * Starts tracking {@code hashAlgorithm} with a freshly created digest, unless it is already
     * tracked.
     */
    protected void checkTrackingHash(Short hashAlgorithm)
    {
        if (!hashes.containsKey(hashAlgorithm))
        {
            Digest hash = TlsUtils.createHash(hashAlgorithm.shortValue());
            hashes.put(hashAlgorithm, hash);
        }
    }

    /**
     * Ensures {@link #notifyPRFDetermined()} has selected a PRF hash algorithm for this instance.
     *
     * @throws IllegalStateException if no PRF hash algorithm is set
     */
    private void checkPRFHashDetermined()
    {
        if (prfHashAlgorithm == null)
        {
            throw new IllegalStateException("PRF hash algorithm has not been determined");
        }
    }
}