/**
* Copyright 2015 Palantir Technologies, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.palantir.giraffe.internal;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.fail;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import com.palantir.giraffe.internal.SharedByteArrayStream.SharedInputStream;
import com.palantir.giraffe.internal.SharedByteArrayStream.SharedOutputStream;
/**
 * Tests thread-safety of {@link SharedByteArrayStream}.
 * <p>
 * Each iteration submits one writer task and one reader task against the two
 * ends of a fresh stream. Both tasks wait on a shared barrier before doing any
 * I/O and sleep for 1 ms after roughly 40% of their operations, then the bytes
 * collected by the reader are compared with the bytes the writer produced.
 *
 * @author bkeyes
 */
public class SharedByteArrayStreamConcurrencyTest {

    /**
     * Number of test iterations.
     */
    private static final int ITERATIONS = 15;

    /**
     * Amount of data to read/write in each iteration.
     */
    private static final int DATA_SIZE = 8192;

    /**
     * Number of chunks to split data into for writing. Must not exceed
     * {@link #DATA_SIZE}, otherwise chunks of at least one byte each cannot
     * be produced.
     */
    private static final int DATA_CHUNKS = 64;

    private Random rand;
    private ExecutorService executor;
    private CyclicBarrier barrier;

    @Before
    public void setup() {
        rand = new Random();
        executor = Executors.newFixedThreadPool(2);
        barrier = new CyclicBarrier(2);
    }

    /**
     * Stops the worker threads.
     * <p>
     * {@code shutdownNow()} is used instead of {@code shutdown()} so that a
     * task left running after a failed or timed-out iteration is interrupted
     * rather than leaked. Interruption is best effort: it wakes tasks that are
     * sleeping or waiting on the barrier; whether it also unblocks a pending
     * stream operation depends on the stream implementation.
     */
    @After
    public void teardown() {
        executor.shutdownNow();
    }

    /**
     * Runs {@link #ITERATIONS} rounds of a concurrent write and read over a
     * fresh stream and checks that the reader received exactly the bytes the
     * writer sent.
     *
     * @throws InterruptedException if the test thread is interrupted while
     *         waiting for either task
     */
    @Test
    public void concurrentReadAndWrite() throws InterruptedException {
        for (int i = 0; i < ITERATIONS; i++) {
            SharedByteArrayStream stream = new SharedByteArrayStream();
            SharedInputStream is = stream.getInputStream();
            SharedOutputStream os = stream.getOutputStream();

            Future<byte[]> writeFuture = executor.submit(new WriteAction(os));
            Future<byte[]> readFuture = executor.submit(new ReadAction(is));

            byte[] expected = awaitResult(writeFuture, 30, "write");
            byte[] actual = awaitResult(readFuture, 10, "read");

            assertArrayEquals("incorrect data", expected, actual);
        }
    }

    /**
     * Waits for a task to finish and returns its result, converting task
     * failures and timeouts into assertion failures.
     *
     * @param future the task to wait for
     * @param timeoutSeconds maximum time to wait, in seconds
     * @param action short name of the task, used in the timeout message
     * @return the byte array produced by the task
     * @throws InterruptedException if the calling thread is interrupted
     */
    private static byte[] awaitResult(Future<byte[]> future, long timeoutSeconds, String action)
            throws InterruptedException {
        byte[] result = null;
        try {
            result = future.get(timeoutSeconds, TimeUnit.SECONDS);
        } catch (ExecutionException e) {
            // report the task's own exception as the cause, not the wrapper
            throw new AssertionError("unexpected exception", e.getCause());
        } catch (TimeoutException e) {
            fail("timeout waiting for " + action + " action");
        }
        return result;
    }

    /**
     * Writes {@link #DATA_SIZE} random bytes to a stream in
     * {@link #DATA_CHUNKS} chunks of varying size and returns the bytes written.
     */
    private final class WriteAction implements Callable<byte[]> {
        private final OutputStream os;
        private final byte[] data;
        private final int[] chunkSizes;

        WriteAction(OutputStream os) {
            this.os = os;
            data = new byte[DATA_SIZE];
            rand.nextBytes(data);
            chunkSizes = getChunkSizes();
        }

        @Override
        public byte[] call() throws Exception {
            // try-with-resources closes the stream even when the barrier or a
            // write fails, so the reading side is not left without an
            // end-of-stream signal
            try (OutputStream out = os) {
                barrier.await();
                int offset = 0;
                for (int size : chunkSizes) {
                    out.write(data, offset, size);
                    offset += size;
                    // pause after ~40% of writes to vary thread interleaving
                    if (rand.nextDouble() < 0.40) {
                        Thread.sleep(1);
                    }
                }
            }
            return data;
        }
    }

    /**
     * Reads up to {@link #DATA_SIZE} bytes from a stream and returns the
     * buffer they were read into. If the stream ends early, the unread tail of
     * the buffer is left zero-filled.
     */
    private final class ReadAction implements Callable<byte[]> {
        private final InputStream is;
        private final byte[] data;

        ReadAction(InputStream is) {
            this.is = is;
            data = new byte[DATA_SIZE];
        }

        @Override
        public byte[] call() throws Exception {
            // try-with-resources closes the stream on every exit path
            try (InputStream in = is) {
                barrier.await();
                int offset = 0;
                while (offset < DATA_SIZE) {
                    int r = in.read(data, offset, DATA_SIZE - offset);
                    if (r == -1) {
                        break;
                    }
                    offset += r;
                    // pause after ~40% of reads to vary thread interleaving
                    if (rand.nextDouble() < 0.40) {
                        Thread.sleep(1);
                    }
                }
            }
            return data;
        }
    }

    /**
     * Splits {@link #DATA_SIZE} into {@link #DATA_CHUNKS} random chunk sizes.
     * <p>
     * Each size except the last is drawn from a normal distribution centred on
     * the even share of the bytes still unassigned (standard deviation one
     * quarter of that share) and then clamped so that at least one byte
     * remains for every chunk that follows. The last chunk takes whatever is
     * left.
     *
     * @return an array of {@link #DATA_CHUNKS} sizes, each at least 1, whose
     *         sum is {@link #DATA_SIZE}
     */
    private int[] getChunkSizes() {
        int[] sizes = new int[DATA_CHUNKS];
        int total = 0;
        for (int i = 0; i < sizes.length - 1; i++) {
            int chunksLeft = sizes.length - i; // includes the current chunk
            int bytesLeft = DATA_SIZE - total;
            int size = 1;
            if (chunksLeft < bytesLeft) {
                double mean = (double) bytesLeft / chunksLeft;
                double stdDev = mean / 4;
                size = (int) Math.round(rand.nextGaussian() * stdDev + mean);
                // Reserve one byte for each chunk after this one. Without the
                // upper bound a large sample could use up those bytes and
                // leave the final chunk with a zero or negative length.
                int maxSize = bytesLeft - (chunksLeft - 1);
                size = Math.max(1, Math.min(size, maxSize));
            }
            sizes[i] = size;
            total += size;
        }
        sizes[sizes.length - 1] = DATA_SIZE - total;
        return sizes;
    }
}