package htsjdk.samtools;
import htsjdk.samtools.cram.CramSerilization;
import htsjdk.samtools.cram.build.CramIO;
import htsjdk.samtools.cram.structure.Container;
import htsjdk.samtools.cram.structure.ContainerIO;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.PriorityBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import net.sf.cram.ref.ReferenceSource;
/**
 * A CRAM container writer that converts batches of records into serialized
 * containers on a thread pool, while a single dedicated "post-processor"
 * thread writes the serialized containers to the output stream in the order
 * in which the batches were submitted.
 *
 * <p>
 * Pipeline: {@link #flushContainer()} snapshots the pending records into a
 * {@link Batch} and submits it to the executor; each batch serializes itself
 * into a byte array and places itself on {@code batchResultQueue}; the
 * post-processor takes batches from that queue strictly by ordinal, writes
 * the bytes and feeds the indexer.
 *
 * <p>
 * The first failure raised by a worker or by the post-processor is recorded
 * and re-thrown, wrapped in a {@link RuntimeException}, from the next call to
 * {@link #flushContainer()} or {@link #finish(boolean)}.
 *
 * <p>
 * NOTE(review): {@code samRecords}, {@code samFileHeader}, {@code source},
 * {@code lossyOptions}, {@code cramVersion}, {@code outputStream},
 * {@code offset} and {@code indexer} are not declared in this class and are
 * presumably inherited from {@code CRAMContainerStreamWriter}. The class
 * assumes that {@code flushContainer()} and {@code finish()} are invoked from
 * a single producer thread - confirm against callers.
 */
public class CRAMContainerAsynchWriter extends CRAMContainerStreamWriter {
    /** Sleep/poll interval, in milliseconds, used while waiting for the next in-order batch. */
    private static final long POLL_INTERVAL_MILLIS = 200;

    private final ThreadPoolExecutor es;
    /** Number of batches created so far; only modified when a {@link Batch} is constructed. */
    private long batchCounter = 0;
    /** Serialized batches waiting to be written, ordered by batch ordinal. */
    private final PriorityBlockingQueue<Batch> batchResultQueue;
    /**
     * Ordinal of the next batch to be written. Written only by the
     * post-processor thread and read by {@link #finish(boolean)}, hence volatile.
     */
    private volatile long indexingCounter = 0;
    private final BlockingQueue<Runnable> jobQueue;
    /** Failures raised by workers or the post-processor, in order of arrival. */
    private final List<Throwable> exceptions = Collections.synchronizedList(new ArrayList<Throwable>());
    private final Thread postProcessorThread;

    /**
     * Creates the writer, its worker pool and starts the post-processor thread.
     * The first five arguments are passed unchanged to the superclass constructor.
     *
     * @param threadPoolSize maximum number of conversion threads, at least 1
     * @throws IllegalArgumentException if {@code threadPoolSize} is less than 1
     */
    public CRAMContainerAsynchWriter(OutputStream outputStream, OutputStream indexStream, ReferenceSource source,
            SAMFileHeader samFileHeader, String cramId, int threadPoolSize) {
        super(outputStream, indexStream, source, samFileHeader, cramId);
        if (threadPoolSize < 1) {
            throw new IllegalArgumentException("Need at least 1 worker thread for asynch CRAM writing.");
        }

        int maxJobs = threadPoolSize * 2;
        jobQueue = new ArrayBlockingQueue<Runnable>(maxJobs);
        // given the rejected execution policy an extra slot needed:
        batchResultQueue = new PriorityBlockingQueue<CRAMContainerAsynchWriter.Batch>(maxJobs + 1);

        System.out.println("Starting thread pool max size: " + threadPoolSize);
        es = new ThreadPoolExecutor(1, threadPoolSize, 10, TimeUnit.SECONDS, jobQueue);
        // When the job queue is full the submitting thread runs the batch itself,
        // which throttles the producer instead of rejecting work.
        es.setRejectedExecutionHandler(new ThreadPoolExecutor.CallerRunsPolicy());

        postProcessorThread = new Thread(new PostProcessor());
        postProcessorThread.start();
    }

    /**
     * Records a failure and tears the pipeline down: no further jobs are
     * accepted, queued work is dropped and the post-processor is asked to stop.
     *
     * @param t the failure raised by a worker or by the post-processor
     */
    private void exceptionInWorker(Throwable t) {
        // Record the failure first so it is never lost if the cleanup below fails.
        exceptions.add(t);
        try {
            t.printStackTrace();
            es.shutdown();
            batchResultQueue.clear();
            jobQueue.clear();
            // Stop the (non-daemon) post-processor so that it does not outlive a failed writer.
            if (Thread.currentThread() != postProcessorThread) {
                postProcessorThread.interrupt();
            }
        } catch (Exception e) {
            // can't do much but whine:
            e.printStackTrace();
        }
    }

    /**
     * Re-throws the first recorded pipeline failure, if there is one.
     *
     * @throws RuntimeException wrapping the first recorded failure
     */
    private void rethrowWorkerFailure() {
        if (!exceptions.isEmpty()) {
            throw new RuntimeException(exceptions.get(0));
        }
    }

    /**
     * Submits the currently buffered records as one batch for asynchronous
     * conversion and clears the buffer. Does nothing if no records are buffered.
     *
     * @throws RuntimeException wrapping the first failure previously raised by
     *                          a worker or by the post-processor
     */
    @Override
    protected void flushContainer() throws IllegalArgumentException, IllegalAccessException, IOException {
        rethrowWorkerFailure();
        if (samRecords.isEmpty()) {
            return;
        }

        // Snapshot the records: the buffer is cleared and reused right away,
        // while the batch is processed later on another thread.
        Batch batch = new Batch(Arrays.asList(samRecords.toArray(new SAMRecord[samRecords.size()])));
        CRAMContainerAsynchWriter.this.samRecords.clear();
        es.execute(batch);
        System.out.printf("Convert jobs: %d, write jobs: %d.\n", jobQueue.size(), batchResultQueue.size());
    }

    /**
     * Flushes the remaining records, waits until every submitted batch has been
     * written by the post-processor, then optionally writes the EOF container,
     * finishes the index and closes the output stream.
     *
     * @param writeEOFContainer whether to append the CRAM EOF container
     * @throws RuntimeException if a worker or the post-processor failed, if the
     *                          calling thread is interrupted while waiting, or
     *                          if writing the trailer fails
     */
    @Override
    public void finish(boolean writeEOFContainer) {
        rethrowWorkerFailure();
        try {
            flushContainer();
            awaitPipelineCompletion();
        } catch (InterruptedException e) {
            // Preserve the interrupt for callers further up the stack.
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }

        // A failure during draining must not be followed by an EOF marker:
        // that would silently produce a truncated file.
        rethrowWorkerFailure();
        if (indexingCounter != batchCounter) {
            throw new IllegalStateException("Not all batches were written: " + indexingCounter + " of "
                    + batchCounter + ".");
        }

        try {
            if (writeEOFContainer) {
                CramIO.issueEOF(cramVersion, outputStream);
            }
            outputStream.flush();
            if (indexer != null) {
                indexer.finish();
            }
            outputStream.close();
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Waits for all conversion jobs to complete and for the post-processor to
     * write every batch, then stops and joins the post-processor thread. Returns
     * early (still stopping the post-processor) once a failure has been recorded.
     *
     * @throws InterruptedException if the calling thread is interrupted while waiting
     */
    private void awaitPipelineCompletion() throws InterruptedException {
        // shutdown() still lets the already queued jobs run to completion.
        es.shutdown();
        try {
            while (!es.awaitTermination(1, TimeUnit.SECONDS)) {
                // keep waiting for the workers
            }
            System.out.println("job queue empty");
            System.out.println("terminated");

            // An empty result queue is not enough: the post-processor removes a batch
            // from the queue before writing it. Wait for the written-batch counter.
            while (indexingCounter < batchCounter && exceptions.isEmpty() && postProcessorThread.isAlive()) {
                postProcessorThread.join(POLL_INTERVAL_MILLIS);
            }
            System.out.println("result queue empty");
        } finally {
            postProcessorThread.interrupt();
        }
        // Joining makes the post-processor's writes visible to this thread before
        // the stream and the indexer are touched again.
        postProcessorThread.join();
    }

    /**
     * A unit of work: a snapshot of records that serializes itself into a CRAM
     * container and queues the result for ordered writing. Batches are ordered
     * by creation ordinal.
     */
    private class Batch implements Runnable, Comparable<Batch> {
        List<SAMRecord> records;
        /** Position of this batch in submission order. */
        final long ordinal = batchCounter++;
        Container container;
        byte[] bytes;

        public Batch(List<SAMRecord> records) {
            this.records = records;
        }

        @Override
        public void run() {
            try {
                container = CramSerilization.convert(records, samFileHeader, source, lossyOptions);
                ByteArrayOutputStream baos = new ByteArrayOutputStream();
                ContainerIO.writeContainer(cramVersion, container, baos);
                bytes = baos.toByteArray();
                batchResultQueue.put(this);
            } catch (Throwable t) {
                // Errors are recorded as well: a silently lost batch would otherwise
                // leave the post-processor waiting for its ordinal forever.
                exceptionInWorker(t);
            }
        }

        @Override
        public int compareTo(Batch o) {
            // Explicit comparison: casting the long difference to int can overflow.
            return ordinal < o.ordinal ? -1 : (ordinal == o.ordinal ? 0 : 1);
        }
    }

    /**
     * Writes serialized batches to the output stream strictly in ordinal order
     * and feeds each written container to the indexer. Runs until interrupted.
     */
    private class PostProcessor implements Runnable {
        @Override
        public void run() {
            try {
                while (!Thread.interrupted()) {
                    Batch head = batchResultQueue.peek();
                    if (head == null || head.ordinal > indexingCounter) {
                        // The next batch in sequence has not been produced yet.
                        try {
                            Thread.sleep(POLL_INTERVAL_MILLIS);
                        } catch (InterruptedException e) {
                            break;
                        }
                        continue;
                    }

                    // poll() rather than take(): the failure handler may clear the
                    // queue between peek() and here, and take() would then block.
                    Batch batch = batchResultQueue.poll();
                    if (batch == null) {
                        continue;
                    }
                    if (batch.ordinal != indexingCounter) {
                        throw new RuntimeException("Batch out of order.");
                    }

                    batch.container.offset = offset;
                    offset += batch.bytes.length;
                    outputStream.write(batch.bytes);
                    if (indexer != null) {
                        indexer.processContainer(batch.container, ValidationStringency.SILENT);
                    }
                    // Advance only after the batch is fully written and indexed, so
                    // finish() can rely on the counter as a completion signal.
                    indexingCounter = batch.ordinal + 1;
                }
            } catch (Throwable t) {
                exceptionInWorker(t);
            }
        }
    }
}