package org.limewire.io; import java.io.ByteArrayInputStream; import java.io.DataInputStream; import java.io.EOFException; import java.io.FilterInputStream; import java.io.IOException; import java.io.InputStream; import java.io.StreamCorruptedException; import java.security.MessageDigest; /** * Reads a stream of bytes written by a {@link SecureOutputStream} and checks * if the bytes are still valid. <code>SecureInputStream</code> throws * exceptions upon problems reading the <code>SecureInputStream</code>. */ public class SecureInputStream extends FilterInputStream { private final MessageDigest md; private final byte[] buffer; private int pos = 0; private int length = -1; public SecureInputStream(InputStream in) throws IOException { this(in, new CRC32MessageDigest()); } public SecureInputStream(InputStream in, MessageDigest md) throws IOException { super(in); if (md == null) { throw new NullPointerException("MessageDigest is null"); } // Read the length of the header int length = 0; for (int i = 0; i < 4; i++) { int b = in.read(); if (b < 0) { throw new EOFException("Couldn't read the length of the header"); } length = (length << 8) | (b & 0xFF); } // A simple sanity check. The length cannot be negative // nor greater than X (4+4+4+name of algorithm). if (length <= 0 || length >= 512) { throw new StreamCorruptedException("Invalid length of the header: " + length); } // Read the actual header byte[] header = new byte[length]; for (int i = 0; i < header.length; i++) { int b = in.read(); if (b < 0) { throw new EOFException("Couldn't read the header"); } header[i] = (byte)(b & 0xFF); } md.update(header, 0, header.length); byte[] actual = md.digest(); // Read the expected checksum and compare it on the fly for (int i = 0; i < actual.length; i++) { int b = in.read(); if (b < 0) { throw new EOFException("Couldn't read the checksum of length " + actual.length); } if (actual[i] != (byte)(b & 0xFF)) { throw new StreamCorruptedException("Header checksums do not match"); } } String algorithm = null; int digestLength = 0; int blockSize = 0; // Get the fields from the header ByteArrayInputStream bias = new ByteArrayInputStream(header); DataInputStream dis = new DataInputStream(bias); algorithm = dis.readUTF(); digestLength = dis.readInt(); blockSize = dis.readInt(); dis.close(); // Make some final sanity checks if (!algorithm.equals(md.getAlgorithm())) { throw new StreamCorruptedException("Expected a MessageDigest of type " + algorithm + " but is " + md.getAlgorithm()); } if (digestLength != md.getDigestLength()) { throw new StreamCorruptedException("Expected a MessageDigest with length " + digestLength + " but is " + md.getDigestLength()); } md.reset(); this.md = md; this.buffer = new byte[blockSize]; } /** * Returns the block (buffer) size of the stream. */ public int getBlockSize() { return buffer.length; } /** * Returns the MessageDigest. */ public MessageDigest getMessageDigest() { return md; } private int refill() throws IOException { assert (pos >= length); // Reset everything pos = 0; length = 0; md.reset(); // Fill the buffer while(length < buffer.length) { int r = in.read(); if (r < 0) { break; } buffer[length++] = (byte)(r & 0xFF); } // Is EOF? if (length == 0) { return -1; } // Get the actual length of the payload int digestLength = md.getDigestLength(); length -= digestLength; if (length <= 0) { throw new StreamCorruptedException("Illegal payload length: " + length); } // Compute the hash md.update(buffer, 0, length); // Compare the actual hash with the expected hash byte[] digest = md.digest(); assert (digest.length == digestLength); for (int i = 0; i < digest.length; i++) { if (digest[i] != buffer[length+i]) { throw new StreamCorruptedException("Checksums do not match"); } } return length; } @Override public int read() throws IOException { if (pos >= length) { refill(); if (pos >= length) { return -1; } } return buffer[pos++] & 0xFF; } @Override public int read(byte[] b, int off, int len) throws IOException { if (b == null) { throw new NullPointerException(); } else if ((off < 0) || (off > b.length) || (len < 0) || ((off + len) > b.length) || ((off + len) < 0)) { throw new IndexOutOfBoundsException(); } else if (len == 0) { return 0; } int total = 0; while(total < len) { if (pos >= length) { refill(); if (pos >= length) { // EOF break; } } int copy = Math.min(length-pos, len-total); System.arraycopy(buffer, pos, b, off+total, copy); pos += copy; total += copy; } return (total > 0 ? total : -1); } }