/*
* JBoss, Home of Professional Open Source.
*
* Copyright 2011 Red Hat, Inc. and/or its affiliates, and individual
* contributors as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.xnio.mock;
import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.net.SocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.FileChannel;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.LockSupport;
import org.xnio.Buffers;
import org.xnio.ChannelListener;
import org.xnio.ChannelListener.Setter;
import org.xnio.IoUtils;
import org.xnio.Option;
import org.xnio.OptionMap;
import org.xnio.XnioExecutor;
import org.xnio.XnioIoThread;
import org.xnio.XnioWorker;
import org.xnio.channels.Channels;
import org.xnio.channels.ConnectedStreamChannel;
import org.xnio.channels.StreamSinkChannel;
import org.xnio.channels.StreamSourceChannel;
/**
 * Mock of a connected stream channel.<p>
 * This channel mock will store everything that is written to it for later comparison, and allows feeding of bytes for
 * reading.
 * <p>
 * Fed read data lives in a fixed-size buffer that is always kept in "read mode": the bytes between its position and
 * its limit are the data that has been fed but not consumed yet. At most one thread may block in
 * {@code awaitReadable} and at most one in {@code awaitWritable} at any given time.
 *
 * @author <a href="mailto:flavia.rainone@jboss.com">Flavia Rainone</a>
 */
public class ConnectedStreamChannelMock implements ConnectedStreamChannel, StreamSourceChannel, StreamSinkChannel, Mock {
    // capacity, in bytes, of the buffer that captures written data
    private static final int WRITE_BUFFER_CAPACITY = 1000;
    // capacity, in bytes, of the buffer that holds data fed for reading
    private static final int READ_BUFFER_CAPACITY = 10000;
    // setReadData puts no header in front of the fed data
    private static final byte[] NO_HEADER = new byte[0];

    // written stuff will be copied to this buffer; written bytes accumulate at its front
    private final ByteBuffer writeBuffer = ByteBuffer.allocate(WRITE_BUFFER_CAPACITY);
    // read stuff will be taken from this buffer; [position, limit) holds the bytes not read yet
    private final ByteBuffer readBuffer = ByteBuffer.allocate(READ_BUFFER_CAPACITY);
    {
        // nothing has been fed yet: start as an empty buffer in read mode
        readBuffer.limit(0);
    }
    // read stuff can only be read if read is enabled
    private boolean readEnabled;
    // can only write when write is enabled
    private boolean writeEnabled = true;
    // indicates if this channel is closed
    private boolean closed = false;
    // when false, read/write keep working even after close()
    private boolean checkClosed = true;
    private boolean writeResumed = false;
    private boolean writeAwaken = false;
    private boolean readAwaken = false;
    private boolean readResumed = false;
    private boolean readsDown = false;
    private boolean writesDown = false;
    private boolean allowShutdownWrites = true;
    private boolean flushed = true;
    private boolean flushEnabled = true;
    // once set, reads return -1 after all pending read data has been consumed
    private boolean eof = false;
    private XnioWorker worker = new XnioWorkerMock();
    private final XnioIoThread executor = new XnioIoThreadMock(null);
    // thread currently parked in awaitReadable / awaitWritable, if any
    private Thread readWaiter;
    private Thread writeWaiter;
    private ChannelListener<? super ConnectedStreamChannel> readListener;
    private ChannelListener<? super ConnectedStreamChannel> writeListener;
    private ChannelListener<? super ConnectedStreamChannel> closeListener;
    private String info = null; // any extra information regarding this channel used by tests
    private OptionMap optionMap;
    private SocketAddress peerAddress;
    private SocketAddress localAddress;

    // listener setters
    private final ChannelListener.Setter<ConnectedStreamChannel> readListenerSetter = new ChannelListener.Setter<ConnectedStreamChannel>() {
        @Override
        public void set(ChannelListener<? super ConnectedStreamChannel> listener) {
            readListener = listener;
        }
    };
    private final ChannelListener.Setter<ConnectedStreamChannel> writeListenerSetter = new ChannelListener.Setter<ConnectedStreamChannel>() {
        @Override
        public void set(ChannelListener<? super ConnectedStreamChannel> listener) {
            writeListener = listener;
        }
    };
    private final ChannelListener.Setter<ConnectedStreamChannel> closeListenerSetter = new ChannelListener.Setter<ConnectedStreamChannel>() {
        @Override
        public void set(ChannelListener<? super ConnectedStreamChannel> listener) {
            closeListener = listener;
        }
    };

    /**
     * Feeds {@code readData} to read clients.
     * <p>
     * Each string is encoded as UTF-8 and appended after any data that has not been read yet. If at least one byte was
     * added, reads are enabled, and a thread is blocked in {@code awaitReadable}, that thread is woken up.
     *
     * @param readData data that will be available for reading on this channel mock
     * @throws RuntimeException if the pending read data plus {@code readData} does not fit in the read buffer
     */
    public void setReadData(String... readData) {
        final byte[][] chunks = encode(readData);
        final Thread waiter;
        synchronized (this) {
            appendReadData(NO_HEADER, chunks);
            waiter = byteCount(chunks) == 0 ? null : detachReadWaiter();
        }
        if (waiter != null) {
            LockSupport.unpark(waiter);
        }
    }

    /**
     * Feeds {@code readData} to read clients, preceded by a 4-byte big-endian int holding the UTF-8 byte length of
     * all of {@code readData}.
     *
     * @param readData data that will be available for reading on this channel mock
     * @throws IllegalStateException if {@link #setEof()} (or {@link #shutdownWrites()}) has already been called
     * @throws RuntimeException if the pending read data plus the new data does not fit in the read buffer
     */
    public void setReadDataWithLength(String... readData) {
        final byte[][] chunks = encode(readData);
        feedLengthPrefixed(byteCount(chunks), chunks);
    }

    /**
     * Feeds {@code readData} to read clients, preceded by a 4-byte big-endian int holding {@code length}. The prefix
     * does not need to match the actual size of {@code readData}, which allows tests to feed malformed frames.
     *
     * @param length value written as the length prefix
     * @param readData data that will be available for reading on this channel mock
     * @throws IllegalStateException if {@link #setEof()} (or {@link #shutdownWrites()}) has already been called
     * @throws RuntimeException if the pending read data plus the new data does not fit in the read buffer
     */
    public void setReadDataWithLength(int length, String... readData) {
        feedLengthPrefixed(length, encode(readData));
    }

    /**
     * Appends a length prefix followed by {@code chunks} to the read data and wakes the read waiter, if any and reads
     * are enabled.
     */
    private void feedLengthPrefixed(int length, byte[][] chunks) {
        final Thread waiter;
        synchronized (this) {
            if (eof) {
                throw new IllegalStateException("Cannot add read data once eof is set");
            }
            final byte[] prefix = ByteBuffer.allocate(4).putInt(length).array();
            appendReadData(prefix, chunks);
            // the prefix alone is readable data, so the waiter is always a candidate for waking
            waiter = detachReadWaiter();
        }
        if (waiter != null) {
            LockSupport.unpark(waiter);
        }
    }

    /**
     * Encodes every string as UTF-8.
     */
    private static byte[][] encode(String... data) {
        final byte[][] encoded = new byte[data.length][];
        for (int i = 0; i < data.length; i++) {
            try {
                encoded[i] = data[i].getBytes("UTF-8");
            } catch (UnsupportedEncodingException e) {
                throw new RuntimeException(e);
            }
        }
        return encoded;
    }

    /**
     * Returns the total number of bytes in {@code chunks}.
     */
    private static int byteCount(byte[][] chunks) {
        int total = 0;
        for (byte[] chunk : chunks) {
            total += chunk.length;
        }
        return total;
    }

    /**
     * Appends {@code header} and then every chunk after the unread read data. The caller must hold this object's
     * monitor. On failure the pending data is kept intact.
     *
     * @throws RuntimeException if there is not enough room for the new bytes
     */
    private void appendReadData(byte[] header, byte[][] chunks) {
        final int needed = header.length + byteCount(chunks);
        // move unread bytes to the front; the buffer is now in write mode with position == unread byte count
        readBuffer.compact();
        try {
            if (readBuffer.remaining() < needed) {
                throw new RuntimeException("ReadBuffer is full - not enough space to add more read data");
            }
            readBuffer.put(header);
            for (byte[] chunk : chunks) {
                readBuffer.put(chunk);
            }
        } finally {
            // back to read mode: [0, limit) is everything still unread, old and new
            readBuffer.flip();
        }
    }

    /**
     * Detaches the thread blocked in {@code awaitReadable} so it can be unparked. The caller must hold this object's
     * monitor.
     *
     * @return the waiting reader, or {@code null} if there is none or reads are disabled
     */
    private Thread detachReadWaiter() {
        final Thread waiter = readWaiter;
        if (waiter == null || !readEnabled) {
            return null;
        }
        readWaiter = null;
        return waiter;
    }

    /**
     * Marks the end of the read stream. Once all pending read data is consumed, reads return -1. Wakes the read
     * waiter if there is one and reads are enabled.
     */
    public void setEof() {
        final Thread waiter;
        synchronized (this) {
            eof = true;
            waiter = detachReadWaiter();
        }
        if (waiter != null) {
            LockSupport.unpark(waiter);
        }
    }

    /**
     * Enables or disables reading. When enabling while data (or eof) is available, the read waiter is woken up.
     *
     * @param enable whether reads should return data
     */
    public void enableRead(boolean enable) {
        final Thread waiter;
        synchronized (this) {
            readEnabled = enable;
            waiter = (readBuffer.hasRemaining() || eof) ? detachReadWaiter() : null;
        }
        if (waiter != null) {
            LockSupport.unpark(waiter);
        }
    }

    /**
     * Enables or disables writing, waking the write waiter (if any) in either case.
     *
     * @param enable whether writes should accept data
     */
    public void enableWrite(boolean enable) {
        final Thread waiter;
        synchronized (this) {
            writeEnabled = enable;
            waiter = writeWaiter;
            writeWaiter = null;
        }
        if (waiter != null) {
            LockSupport.unpark(waiter);
        }
    }

    /**
     * Controls whether read and write operations throw {@link ClosedChannelException} after {@link #close()}.
     */
    public synchronized void enableClosedCheck(boolean enable) {
        checkClosed = enable;
    }

    /**
     * Returns all the bytes that have been written to this channel mock.
     *
     * @return the written bytes decoded by {@link Buffers#getModifiedUtf8}, or {@code ""} if nothing was written
     */
    public synchronized String getWrittenText() {
        if (writeBuffer.position() == 0 && writeBuffer.limit() == writeBuffer.capacity()) {
            return "";
        }
        // view [0, written) for decoding; a later write() resets the limit to the capacity before appending
        writeBuffer.flip();
        return Buffers.getModifiedUtf8(writeBuffer);
    }

    /**
     * Returns the internal buffer that captures written bytes (not a copy).
     */
    public ByteBuffer getWrittenBytes() {
        return writeBuffer;
    }

    @Override
    public void close() throws IOException {
        closed = true;
        shutdownWrites();
        shutdownReads();
    }

    @Override
    public boolean isOpen() {
        return !closed;
    }

    @Override
    public boolean supportsOption(Option<?> option) {
        return optionMap == null? false: optionMap.contains(option);
    }

    @Override
    public <T> T getOption(Option<T> option) throws IOException {
        return optionMap == null? null: optionMap.get(option);
    }

    @Override
    public <T> T setOption(Option<T> option, T value) throws IllegalArgumentException, IOException {
        final OptionMap.Builder optionMapBuilder = OptionMap.builder();
        T previousValue = null;
        if (optionMap != null) {
            optionMapBuilder.addAll(optionMap);
            previousValue = optionMap.get(option);
        }
        optionMapBuilder.set(option, value);
        optionMap = optionMapBuilder.getMap();
        return previousValue;
    }

    public void setOptionMap(OptionMap optionMap) {
        this.optionMap = optionMap;
    }

    @Override
    public OptionMap getOptionMap() {
        return optionMap;
    }

    @Override
    public void suspendReads() {
        readAwaken = false;
        readResumed = false;
    }

    @Override
    public void resumeReads() {
        readResumed = true;
    }

    @Override
    public boolean isReadResumed() {
        return readResumed;
    }

    @Override
    public void shutdownReads() throws IOException {
        readsDown = true;
    }

    public boolean isShutdownReads() {
        return readsDown;
    }

    /**
     * This mock supports only one read thread waiting at most. May return before data is available (for instance if
     * reads are toggled), which the channel contract permits.
     */
    @Override
    public void awaitReadable() throws IOException {
        if (registerReadWaiter()) {
            LockSupport.park(this);
            unregisterReadWaiter();
        }
    }

    /**
     * This mock supports only one read thread waiting at most. Any {@link TimeUnit} is accepted.
     */
    @Override
    public void awaitReadable(long time, TimeUnit timeUnit) throws IOException {
        if (registerReadWaiter()) {
            LockSupport.parkNanos(this, timeUnit.toNanos(time));
            unregisterReadWaiter();
        }
    }

    /**
     * Registers the current thread as read waiter unless the channel is already readable.
     *
     * @return {@code true} if the caller must park, {@code false} if it can return immediately
     */
    private synchronized boolean registerReadWaiter() {
        if (readWaiter != null) {
            throw new IllegalStateException("ConnectedStreamChannelMock can be used only with one read waiter thread at most... there is already a waiting thread" + readWaiter);
        }
        if ((readBuffer.hasRemaining() || eof) && readEnabled) {
            return false;
        }
        readWaiter = Thread.currentThread();
        return true;
    }

    /**
     * Clears the read waiter if it is still the current thread (a waker may already have detached it).
     */
    private synchronized void unregisterReadWaiter() {
        if (readWaiter == Thread.currentThread()) {
            readWaiter = null;
        }
    }

    @Override
    @Deprecated
    public XnioExecutor getReadThread() {
        return executor;
    }

    @Override
    public XnioIoThread getIoThread() {
        return executor;
    }

    @Override
    public void suspendWrites() {
        writeAwaken = false;
        writeResumed = false;
    }

    @Override
    public void resumeWrites() {
        writeResumed = true;
    }

    @Override
    public boolean isWriteResumed() {
        return writeResumed;
    }

    /**
     * Shuts writes down. In this mock that also sets eof on the read side and wakes the read waiter, if any
     * (regardless of whether reads are enabled).
     */
    @Override
    public synchronized void shutdownWrites() throws IOException {
        if (!allowShutdownWrites) {
            return;
        }
        writesDown = true;
        eof = true;
        final Thread waiter = readWaiter;
        if (waiter != null) {
            readWaiter = null;
            LockSupport.unpark(waiter);
        }
    }

    public boolean isShutdownWrites() {
        return writesDown;
    }

    /**
     * This mock supports only one write thread waiting at most.
     */
    @Override
    public void awaitWritable() throws IOException {
        if (registerWriteWaiter()) {
            LockSupport.park(this);
            unregisterWriteWaiter();
        }
    }

    /**
     * This mock supports only one write thread waiting at most. Any {@link TimeUnit} is accepted.
     */
    @Override
    public void awaitWritable(long time, TimeUnit timeUnit) throws IOException {
        if (registerWriteWaiter()) {
            LockSupport.parkNanos(this, timeUnit.toNanos(time));
            unregisterWriteWaiter();
        }
    }

    /**
     * Registers the current thread as write waiter unless writes are enabled.
     *
     * @return {@code true} if the caller must park, {@code false} if it can return immediately
     */
    private synchronized boolean registerWriteWaiter() {
        if (writeWaiter != null) {
            throw new IllegalStateException("ConnectedStreamChannelMock can be used only with one write waiter thread at most... there is already a waiting thread" + writeWaiter);
        }
        if (writeEnabled) {
            return false;
        }
        writeWaiter = Thread.currentThread();
        return true;
    }

    /**
     * Clears the write waiter if it is still the current thread (enableWrite may already have detached it).
     */
    private synchronized void unregisterWriteWaiter() {
        if (writeWaiter == Thread.currentThread()) {
            writeWaiter = null;
        }
    }

    @Override
    @Deprecated
    public XnioExecutor getWriteThread() {
        return executor;
    }

    /**
     * Marks written data as flushed if flushing is enabled.
     *
     * @return whether all written data is flushed
     */
    @Override
    public synchronized boolean flush() throws IOException {
        if (flushEnabled) {
            flushed = true;
        }
        return flushed;
    }

    public boolean isFlushed() {
        return flushed;
    }

    public synchronized void enableFlush(boolean enable) {
        flushEnabled = enable;
    }

    @Override
    public long transferFrom(FileChannel src, long position, long count) throws IOException {
        if (writeEnabled) {
            return src.transferTo(position, count, this);
        }
        return 0;
    }

    /**
     * Transfers from {@code source} into this channel when writes are enabled.
     *
     * @return the number of bytes transferred as reported by {@link IoUtils#transfer}, or 0 if writes are disabled
     */
    @Override
    public long transferFrom(final StreamSourceChannel source, final long count, final ByteBuffer throughBuffer) throws IOException {
        if (writeEnabled) {
            return IoUtils.transfer(source, count, throughBuffer, this);
        }
        return 0;
    }

    /**
     * Throws if this channel was closed and closed checks are enabled. Called with the monitor held.
     */
    private void ensureOpen() throws ClosedChannelException {
        if (closed && checkClosed) {
            throw new ClosedChannelException();
        }
    }

    /**
     * Restores the write buffer limit to its capacity, undoing the flip done by {@link #getWrittenText()}.
     */
    private void prepareWriteBuffer() {
        if (writeBuffer.limit() < writeBuffer.capacity()) {
            writeBuffer.limit(writeBuffer.capacity());
        }
    }

    @Override
    public synchronized int write(ByteBuffer src) throws IOException {
        ensureOpen();
        if (!writeEnabled) {
            return 0;
        }
        prepareWriteBuffer();
        final int bytes = Buffers.copy(writeBuffer, src);
        if (bytes > 0) {
            flushed = false;
        }
        return bytes;
    }

    @Override
    public synchronized long write(ByteBuffer[] srcs, int offset, int length) throws IOException {
        ensureOpen();
        if (!writeEnabled) {
            return 0;
        }
        prepareWriteBuffer();
        final int bytes = Buffers.copy(writeBuffer, srcs, offset, length);
        if (bytes > 0) {
            flushed = false;
        }
        return bytes;
    }

    /**
     * Same as {@code write(srcs, 0, srcs.length)}, including marking the channel as not flushed.
     */
    @Override
    public synchronized long write(ByteBuffer[] srcs) throws IOException {
        return write(srcs, 0, srcs.length);
    }

    @Override
    public long transferTo(long position, long count, FileChannel target) throws IOException {
        if (readEnabled) {
            return target.transferFrom(this, position, count);
        }
        return 0;
    }

    @Override
    public long transferTo(final long count, final ByteBuffer throughBuffer, final StreamSinkChannel target) throws IOException {
        if (readEnabled) {
            return IoUtils.transfer(this, count, throughBuffer, target);
        }
        return 0;
    }

    /**
     * Copies pending read data into {@code dst}.
     *
     * @return bytes copied; 0 if reads are disabled or no data is pending; -1 if no data is pending and eof is set
     */
    @Override
    public synchronized int read(ByteBuffer dst) throws IOException {
        ensureOpen();
        if (!readEnabled) {
            return 0;
        }
        if (!readBuffer.hasRemaining()) {
            return eof ? -1 : 0;
        }
        return Buffers.copy(dst, readBuffer);
    }

    /**
     * Scatters pending read data into {@code dsts[offset, offset + length)}; return values as for
     * {@link #read(ByteBuffer)}.
     */
    @Override
    public synchronized long read(ByteBuffer[] dsts, int offset, int length) throws IOException {
        ensureOpen();
        if (!readEnabled) {
            return 0;
        }
        if (!readBuffer.hasRemaining()) {
            return eof ? -1 : 0;
        }
        return Buffers.copy(dsts, offset, length, readBuffer);
    }

    /**
     * Returns {@code true} when no fed read data is pending (including when none was ever fed).
     */
    public synchronized boolean allReadDataConsumed() {
        return readBuffer.position() == readBuffer.limit();
    }

    @Override
    public synchronized long read(ByteBuffer[] dsts) throws IOException {
        return read(dsts, 0, dsts.length);
    }

    @Override
    public SocketAddress getPeerAddress() {
        return peerAddress;
    }

    /**
     * Returns the peer address if one was set and it is an instance of {@code type}, otherwise {@code null}.
     */
    @Override
    public <A extends SocketAddress> A getPeerAddress(Class<A> type) {
        return type.isInstance(peerAddress) ? type.cast(peerAddress) : null;
    }

    public void setPeerAddress(SocketAddress peerAddress) {
        this.peerAddress = peerAddress;
    }

    @Override
    public SocketAddress getLocalAddress() {
        return localAddress;
    }

    /**
     * Returns the local address if one was set and it is an instance of {@code type}, otherwise {@code null}.
     */
    @Override
    public <A extends SocketAddress> A getLocalAddress(Class<A> type) {
        return type.isInstance(localAddress) ? type.cast(localAddress) : null;
    }

    public void setLocalAddress(SocketAddress localAddress) {
        this.localAddress = localAddress;
    }

    @Override
    public XnioWorker getWorker() {
        return worker;
    }

    public void setWorker(XnioWorker worker) {
        this.worker = worker;
    }

    @Override
    public void wakeupReads() {
        readAwaken = true;
        readResumed = true;
    }

    public boolean isReadAwaken() {
        return readAwaken;
    }

    @Override
    public void wakeupWrites() {
        writeAwaken = true;
        writeResumed = true;
    }

    public boolean isWriteAwaken() {
        return writeAwaken;
    }

    @Override
    public Setter<? extends ConnectedStreamChannel> getReadSetter() {
        return readListenerSetter;
    }

    @Override
    public Setter<? extends ConnectedStreamChannel> getWriteSetter() {
        return writeListenerSetter;
    }

    @Override
    public Setter<? extends ConnectedStreamChannel> getCloseSetter() {
        return closeListenerSetter;
    }

    @Override
    public int writeFinal(ByteBuffer src) throws IOException {
        return Channels.writeFinalBasic(this, src);
    }

    @Override
    public long writeFinal(ByteBuffer[] srcs, int offset, int length) throws IOException {
        return Channels.writeFinalBasic(this, srcs, offset, length);
    }

    @Override
    public long writeFinal(ByteBuffer[] srcs) throws IOException {
        return Channels.writeFinalBasic(this, srcs, 0, srcs.length);
    }

    public ChannelListener<? super ConnectedStreamChannel> getReadListener() {
        return readListener;
    }

    public ChannelListener<? super ConnectedStreamChannel> getWriteListener() {
        return writeListener;
    }

    public ChannelListener<? super ConnectedStreamChannel> getCloseListener() {
        return closeListener;
    }

    @Override
    public String getInfo() {
        return info;
    }

    @Override
    public void setInfo(String i) {
        info = i;
    }
}