/*
* Copyright (c) 2014, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*
*/
package com.facebook.crypto.streams;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.FilterInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.Random;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
/**
 * Tests for {@link TailInputStream}.
 *
 * <p>Each test drains a {@link TailInputStream} wrapping random input data, using a different
 * read pattern: single bytes, buffers smaller than, equal to, or larger than the tail, or an
 * underlying stream that returns short reads. The bytes returned by the reads and the bytes
 * returned by {@link TailInputStream#getTail()} are then handed to
 * {@code TailBufferHelper.verifyDataAndTailMatch} together with the original input.
 * NOTE(review): that helper presumably asserts that the read data followed by the tail
 * reproduces the input; confirm against its implementation.
 */
public class TailInputStreamTest {
  private static final int TAIL_LENGTH = 16;

  /** Fixed seed so that any failing run can be reproduced exactly. */
  private static final long RANDOM_SEED = 0x5EEDL;

  private TailInputStream mTailInputStream;
  private InputStream mInputStream;
  private byte[] mInputData;

  @Before
  public void setUp() {
    // Deliberately not a multiple of TAIL_LENGTH, so the data does not split into whole
    // tail-sized chunks.
    mInputData = new byte[TAIL_LENGTH * 20 + 7];
    new Random(RANDOM_SEED).nextBytes(mInputData);
    mInputStream = new ByteArrayInputStream(mInputData);
    mTailInputStream = new TailInputStream(mInputStream, TAIL_LENGTH);
  }

  @Test
  public void testReadInSmallIncrements() throws IOException {
    readAndVerify(mTailInputStream, TAIL_LENGTH / 3);
  }

  @Test
  public void testReadInTagSizeIncrements() throws IOException {
    readAndVerify(mTailInputStream, TAIL_LENGTH);
  }

  @Test
  public void testReadInLargeIncrements() throws IOException {
    readAndVerify(mTailInputStream, TAIL_LENGTH * 2 + 3);
  }

  @Test
  public void testReadOneByteAtATime() throws IOException {
    ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
    int read;
    while ((read = mTailInputStream.read()) != -1) {
      outputStream.write(read);
    }
    byte[] tail = mTailInputStream.getTail();
    byte[] readData = outputStream.toByteArray();
    TailBufferHelper.verifyDataAndTailMatch(mInputData, readData, tail, TAIL_LENGTH);
  }

  /**
   * Input shorter than the tail length must produce an {@link IOException}, either while
   * reading or from {@link TailInputStream#getTail()}.
   */
  @Test(expected = IOException.class)
  public void throwsWhenTailNotSufficient() throws IOException {
    byte[] data = new byte[TAIL_LENGTH - 1];
    TailInputStream tailStream =
        new TailInputStream(new ByteArrayInputStream(data), TAIL_LENGTH);
    byte[] temp = new byte[TAIL_LENGTH * 2 + 3];
    ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
    readFully(tailStream, temp, outputStream);
    tailStream.getTail();
  }

  /** The underlying stream returns fewer bytes than requested on every bulk read. */
  @Test
  public void testBytesReturnedByUnderlyingStreamIsReduced() throws IOException {
    InputStream reducingStream = new ByteReducingInputStream(mInputStream, 2);
    readAndVerify(new TailInputStream(reducingStream, TAIL_LENGTH), TAIL_LENGTH * 2 + 3);
  }

  /**
   * Drains {@code stream} with a buffer of {@code bufferSize} bytes, then passes the collected
   * data and the stream's tail to {@code TailBufferHelper.verifyDataAndTailMatch} along with
   * the original input.
   *
   * @param stream the stream under test; it is read to EOF
   * @param bufferSize size of the buffer passed to each {@code read(byte[])} call
   * @throws IOException if reading or {@code getTail()} fails
   */
  private void readAndVerify(TailInputStream stream, int bufferSize) throws IOException {
    ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
    readFully(stream, new byte[bufferSize], outputStream);
    byte[] tail = stream.getTail();
    byte[] readData = outputStream.toByteArray();
    TailBufferHelper.verifyDataAndTailMatch(mInputData, readData, tail, TAIL_LENGTH);
  }

  /**
   * Reads {@code input} to EOF through {@code tempBuffer} and appends everything read to
   * {@code output}.
   *
   * @param input stream to drain
   * @param tempBuffer non-empty scratch buffer used for each {@code read(byte[])} call
   * @param output destination for all bytes read
   * @throws IOException if the stream fails
   */
  private static void readFully(InputStream input,
                                byte[] tempBuffer,
                                ByteArrayOutputStream output) throws IOException {
    int read;
    while ((read = input.read(tempBuffer)) != -1) {
      // InputStream.read(byte[]) on a non-empty buffer must return at least one byte or -1.
      Assert.assertTrue("read(byte[]) returned " + read, read > 0);
      output.write(tempBuffer, 0, read);
    }
  }

  /**
   * A {@link FilterInputStream} that simulates short reads. Each bulk read asks the wrapped
   * stream for only {@code count / reduction} bytes, and never fewer than one when
   * {@code count > 0}.
   */
  private static class ByteReducingInputStream extends FilterInputStream {
    private final float mReduction;

    /**
     * @param in stream to wrap
     * @param reduction divisor applied to each requested count; must be at least 1 so the
     *     wrapped stream is never asked for more bytes than the caller's buffer holds
     * @throws IllegalArgumentException if {@code reduction} is less than 1 or NaN
     */
    protected ByteReducingInputStream(InputStream in, float reduction) {
      super(in);
      if (!(reduction >= 1f)) {
        throw new IllegalArgumentException("reduction must be >= 1, was " + reduction);
      }
      mReduction = reduction;
    }

    @Override
    public int read(byte[] buffer, int offset, int count) throws IOException {
      if (count <= 0) {
        // Let the wrapped stream apply its own contract for empty or invalid requests.
        return in.read(buffer, offset, count);
      }
      // Truncation could yield 0. A 0-byte request returns 0 without progress, which could
      // make a caller that loops until data arrives spin forever, so always request >= 1 byte.
      int newCount = Math.max(1, (int) (count / mReduction));
      return in.read(buffer, offset, newCount);
    }
  }
}