package org.limewire.io; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.DataInputStream; import java.io.DataOutputStream; import java.io.IOException; import java.io.StreamCorruptedException; import java.security.MessageDigest; import junit.framework.Test; import org.limewire.util.BaseTestCase; import org.limewire.util.StringUtils; public class SecureInputOutputTest extends BaseTestCase { public SecureInputOutputTest(String name) { super(name); } public static void main(String[] args) { junit.textui.TestRunner.run(suite()); } public static Test suite() { return buildTestSuite(SecureInputOutputTest.class); } public void testSecureInputOutput() throws IOException { // Use some odd number int blockSize = 123; ByteArrayOutputStream baos = new ByteArrayOutputStream(); SecureOutputStream sos = new SecureOutputStream(baos, blockSize); DataOutputStream dos = new DataOutputStream(sos); Class[] types = new Class[] { Boolean.class, Byte.class, Short.class, Integer.class, Float.class, Long.class, Double.class, String.class }; String[] str = new String[] { "Hello World", "LimeWire", "Mojito", "Gnutella Network" }; for (int i = 0, j = 0, k = 0; i < 4096; i++) { Class clazz = types[i % types.length]; if (clazz.equals(Boolean.class)) { dos.writeBoolean((j++ % 2) == 0); } else if (clazz.equals(Byte.class)) { dos.write((byte)i); } else if (clazz.equals(Short.class)) { dos.writeShort((short)i); } else if (clazz.equals(Integer.class)) { dos.writeInt(i); } else if (clazz.equals(Float.class)) { dos.writeFloat(i); } else if (clazz.equals(Long.class)) { dos.writeLong(i); } else if (clazz.equals(Double.class)) { dos.writeDouble(i); } else if (clazz.equals(String.class)) { dos.writeUTF(str[k++ % str.length]); } } dos.close(); ByteArrayInputStream bais = new ByteArrayInputStream(baos.toByteArray()); SecureInputStream sis = new SecureInputStream(bais); DataInputStream dis = new DataInputStream(sis); assertEquals(blockSize, sis.getBlockSize()); assertEquals(sos.getBlockSize(), sis.getBlockSize()); assertEquals(sos.getMessageDigest().getAlgorithm(), sis.getMessageDigest().getAlgorithm()); for (int i = 0, j = 0, k = 0; i < 4096; i++) { Class clazz = types[i % types.length]; if (clazz.equals(Boolean.class)) { assertEquals((j++ % 2) == 0, dis.readBoolean()); } else if (clazz.equals(Byte.class)) { assertEquals((byte)i, dis.readByte()); } else if (clazz.equals(Short.class)) { assertEquals((short)i, dis.readShort()); } else if (clazz.equals(Integer.class)) { assertEquals(i, dis.readInt()); } else if (clazz.equals(Float.class)) { assertEquals((float)i, dis.readFloat()); } else if (clazz.equals(Long.class)) { assertEquals(i, dis.readLong()); } else if (clazz.equals(Double.class)) { assertEquals((double)i, dis.readDouble()); } else if (clazz.equals(String.class)) { assertEquals(str[k++ % str.length], dis.readUTF()); } } dis.close(); } public void testSecureInputOutputHeader() throws IOException { String algorithm = "This is a very long MessageDigest algorithm name! " + "In fact its length must be at least 256 bytes to make sure " + "SecureOutputStream's length field in the header is utilitzing " + "more than one byte for the length. So what I am doing here is " + "to define an extremly long algorithm name for the fake MessageDigest " + "implementation. The goal is to make sure the for-loop that re-assembles " + "the length field works correctly"; assertGreaterThan(0xFF, StringUtils.toAsciiBytes(algorithm).length); MessageDigest md = new FakeMessageDigest(algorithm, 16); // Use some odd number int blockSize = 123; ByteArrayOutputStream baos = new ByteArrayOutputStream(); SecureOutputStream sos = new SecureOutputStream(baos, md, blockSize); DataOutputStream dos = new DataOutputStream(sos); Class[] types = new Class[] { Boolean.class, Byte.class, Short.class, Integer.class, Float.class, Long.class, Double.class, String.class }; String[] str = new String[] { "Hello World", "LimeWire", "Mojito", "Gnutella Network" }; for (int i = 0, j = 0, k = 0; i < 4096; i++) { Class clazz = types[i % types.length]; if (clazz.equals(Boolean.class)) { dos.writeBoolean((j++ % 2) == 0); } else if (clazz.equals(Byte.class)) { dos.write((byte)i); } else if (clazz.equals(Short.class)) { dos.writeShort((short)i); } else if (clazz.equals(Integer.class)) { dos.writeInt(i); } else if (clazz.equals(Float.class)) { dos.writeFloat(i); } else if (clazz.equals(Long.class)) { dos.writeLong(i); } else if (clazz.equals(Double.class)) { dos.writeDouble(i); } else if (clazz.equals(String.class)) { dos.writeUTF(str[k++ % str.length]); } } dos.close(); ByteArrayInputStream bais = new ByteArrayInputStream(baos.toByteArray()); SecureInputStream sis = new SecureInputStream(bais, md); DataInputStream dis = new DataInputStream(sis); assertEquals(blockSize, sis.getBlockSize()); assertEquals(sos.getBlockSize(), sis.getBlockSize()); assertEquals(sos.getMessageDigest().getAlgorithm(), sis.getMessageDigest().getAlgorithm()); for (int i = 0, j = 0, k = 0; i < 4096; i++) { Class clazz = types[i % types.length]; if (clazz.equals(Boolean.class)) { assertEquals((j++ % 2) == 0, dis.readBoolean()); } else if (clazz.equals(Byte.class)) { assertEquals((byte)i, dis.readByte()); } else if (clazz.equals(Short.class)) { assertEquals((short)i, dis.readShort()); } else if (clazz.equals(Integer.class)) { assertEquals(i, dis.readInt()); } else if (clazz.equals(Float.class)) { assertEquals((float)i, dis.readFloat()); } else if (clazz.equals(Long.class)) { assertEquals(i, dis.readLong()); } else if (clazz.equals(Double.class)) { assertEquals((double)i, dis.readDouble()); } else if (clazz.equals(String.class)) { assertEquals(str[k++ % str.length], dis.readUTF()); } } dis.close(); } private static class FakeMessageDigest extends MessageDigest { private int digestLength; public FakeMessageDigest(String algorithm, int diesgtLength) { super(algorithm); this.digestLength = diesgtLength; } @Override protected int engineGetDigestLength() { return digestLength; } @Override protected byte[] engineDigest() { return new byte[digestLength]; } @Override protected void engineReset() { } @Override protected void engineUpdate(byte input) { } @Override protected void engineUpdate(byte[] input, int offset, int len) { } } public void testCorruptHeader() throws IOException { ByteArrayOutputStream baos = new ByteArrayOutputStream(); SecureOutputStream sos = new SecureOutputStream(baos, 128); DataOutputStream dos = new DataOutputStream(sos); dos.writeUTF("Hello World!"); dos.close(); byte[] data = baos.toByteArray(); data[5] = (byte)~data[5]; try { ByteArrayInputStream bais = new ByteArrayInputStream(data); SecureInputStream sis = new SecureInputStream(bais); DataInputStream dis = new DataInputStream(sis); fail("Should have thrown a StreamCorruptedException!"); dis.close(); } catch (StreamCorruptedException ignored) {} } public void testCorruptPayload() throws IOException { ByteArrayOutputStream baos = new ByteArrayOutputStream(); SecureOutputStream sos = new SecureOutputStream(baos, 128); DataOutputStream dos = new DataOutputStream(sos); for (int i = 0; i < 128; i++) { dos.writeShort(i); } dos.close(); byte[] data = baos.toByteArray(); assertGreaterThan(2*128, data.length); // Corrupt the second block data[200] = (byte)~data[200]; ByteArrayInputStream bais = new ByteArrayInputStream(data); SecureInputStream sis = new SecureInputStream(bais); DataInputStream dis = new DataInputStream(sis); // Read the first block but make sure SecureInputStream // doesn't call refill which would trigger StreamCorruptedException dis.readFully(new byte[50]); try { dis.readFully(new byte[128]); fail("Should have thrown a StreamCorruptedException!"); } catch (StreamCorruptedException ignored) {} } }