package org.rascalmpl.repl;
import java.io.IOException;
import java.io.InputStream;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
public class NotifieableInputStream extends InputStream {
private final ConcurrentLinkedQueue<Byte> queue;
private volatile boolean closed;
private volatile IOException toThrow;
private final Semaphore newData = new Semaphore(0);
private final Thread reader;
private final InputStream peekAt;
/**
* scan for certain bytes in the stream, and if they are found, call the callback function to see if it has to be swallowed.
*/
public NotifieableInputStream(final InputStream peekAt, final byte[] watchFor, final Function<Byte, Boolean> swallow) {
this.queue = new ConcurrentLinkedQueue<Byte>();
this.closed = false;
this.toThrow = null;
this.peekAt = peekAt;
this.reader = new Thread(new Runnable() {
@Override
public void run() {
try {
reading:
while (!closed) {
int b = peekAt.read();
if (b == -1) {
NotifieableInputStream.this.close();
return;
}
for (byte c: watchFor) {
if (b == c) {
if (swallow.apply((byte)b)) {
continue reading;
}
break;
}
}
queue.offer((byte)b);
newData.release();
}
}
catch(IOException e) {
if (!e.getMessage().contains("closed")) {
toThrow = e;
try {
NotifieableInputStream.this.close();
}
catch (IOException e1) {
}
}
}
}
});
reader.setName("InputStream scanner");
reader.setDaemon(true);
reader.start();
}
@Override
public int read(byte[] b) throws IOException {
return read(b, 0, b.length);
}
@Override
public int read(byte[] b, int off, int len) throws IOException {
if (b == null) {
throw new NullPointerException();
} else if ((off < 0) || (off > b.length) || (len < 0) ||
((off + len) > b.length) || ((off + len) < 0)) {
throw new IndexOutOfBoundsException();
} else if (len == 0) {
return 0;
}
// we have to at least read one (so block until we can)
int atLeastOne = read();
if (atLeastOne == -1) {
return -1;
}
int index = off;
b[index++] = (byte) atLeastOne;
// now consume the rest of the available bytes
Byte current;
while ((current = queue.poll()) != null && (index < off + len)) {
b[index++] = current;
}
return index - off;
}
@Override
public int read() throws IOException {
Byte result = null;
while ((result = queue.poll()) == null) {
if (closed) {
return -1;
}
try {
newData.tryAcquire(10, TimeUnit.MILLISECONDS);
}
catch (InterruptedException e) {
return -1;
}
if (toThrow != null) {
IOException throwCopy = toThrow;
toThrow = null;
if (throwCopy != null) {
throw throwCopy;
}
}
}
return (result & 0xFF);
}
@Override
public void close() throws IOException {
closed = true;
peekAt.close();
}
}