package org.bitcoinj.protocols.channels; import org.bitcoinj.core.Coin; import org.bitcoinj.core.Sha256Hash; import org.bitcoinj.core.TransactionBroadcaster; import org.bitcoinj.core.Wallet; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; import com.google.protobuf.ByteString; import org.bitcoin.paymentchannel.Protos; import javax.annotation.Nullable; import java.util.concurrent.BlockingQueue; import java.util.concurrent.LinkedBlockingQueue; import static org.junit.Assert.assertEquals; /** * Various mock objects and utilities for testing payment channels code. */ public class ChannelTestUtils { public static class RecordingServerConnection implements PaymentChannelServer.ServerConnection { public BlockingQueue<Object> q = new LinkedBlockingQueue<Object>(); @Override public void sendToClient(Protos.TwoWayChannelMessage msg) { q.add(msg); } @Override public void destroyConnection(PaymentChannelCloseException.CloseReason reason) { q.add(reason); } @Override public void channelOpen(Sha256Hash contractHash) { q.add(contractHash); } @Override public ListenableFuture<ByteString> paymentIncrease(Coin by, Coin to, @Nullable ByteString info) { q.add(new UpdatePair(to, info)); return Futures.immediateFuture(ByteString.copyFromUtf8(by.toPlainString())); } public Protos.TwoWayChannelMessage getNextMsg() throws InterruptedException { return (Protos.TwoWayChannelMessage) q.take(); } public Protos.TwoWayChannelMessage checkNextMsg(Protos.TwoWayChannelMessage.MessageType expectedType) throws InterruptedException { Protos.TwoWayChannelMessage msg = getNextMsg(); assertEquals(expectedType, msg.getType()); return msg; } public void checkTotalPayment(Coin valueSoFar) throws InterruptedException { Coin lastSeen = ((UpdatePair) q.take()).amount; assertEquals(lastSeen, valueSoFar); } } public static class RecordingClientConnection implements PaymentChannelClient.ClientConnection { public BlockingQueue<Object> q = new LinkedBlockingQueue<Object>(); final static int IGNORE_EXPIRE = -1; private final int maxExpireTime; // An arbitrary sentinel object for equality testing. public static final Object CHANNEL_INITIATED = new Object(); public static final Object CHANNEL_OPEN = new Object(); public RecordingClientConnection(int maxExpireTime) { this.maxExpireTime = maxExpireTime; } @Override public void sendToServer(Protos.TwoWayChannelMessage msg) { q.add(msg); } @Override public void destroyConnection(PaymentChannelCloseException.CloseReason reason) { q.add(reason); } @Override public boolean acceptExpireTime(long expireTime) { return this.maxExpireTime == IGNORE_EXPIRE || expireTime <= maxExpireTime; } @Override public void channelOpen(boolean wasInitiated) { if (wasInitiated) q.add(CHANNEL_INITIATED); q.add(CHANNEL_OPEN); } public Protos.TwoWayChannelMessage getNextMsg() throws InterruptedException { return (Protos.TwoWayChannelMessage) q.take(); } public Protos.TwoWayChannelMessage checkNextMsg(Protos.TwoWayChannelMessage.MessageType expectedType) throws InterruptedException { Protos.TwoWayChannelMessage msg = getNextMsg(); assertEquals(expectedType, msg.getType()); return msg; } public void checkOpened() throws InterruptedException { assertEquals(CHANNEL_OPEN, q.take()); } public void checkInitiated() throws InterruptedException { assertEquals(CHANNEL_INITIATED, q.take()); checkOpened(); } } public static class RecordingPair { public PaymentChannelServer server; public RecordingServerConnection serverRecorder; public RecordingClientConnection clientRecorder; } public static RecordingPair makeRecorders(final Wallet serverWallet, final TransactionBroadcaster mockBroadcaster) { return makeRecorders(serverWallet, mockBroadcaster, RecordingClientConnection.IGNORE_EXPIRE); } public static RecordingPair makeRecorders(final Wallet serverWallet, final TransactionBroadcaster mockBroadcaster, int maxExpireTime) { RecordingPair pair = new RecordingPair(); pair.serverRecorder = new RecordingServerConnection(); pair.server = new PaymentChannelServer(mockBroadcaster, serverWallet, Coin.COIN, pair.serverRecorder); pair.clientRecorder = new RecordingClientConnection(maxExpireTime); return pair; } public static class UpdatePair { public Coin amount; public ByteString info; public UpdatePair(Coin amount, ByteString info) { this.amount = amount; this.info = info; } @Override public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; UpdatePair that = (UpdatePair) o; if (amount != null ? !amount.equals(that.amount) : that.amount != null) return false; if (info != null ? !info.equals(that.info) : that.info != null) return false; return true; } @Override public int hashCode() { int result = amount != null ? amount.hashCode() : 0; result = 31 * result + (info != null ? info.hashCode() : 0); return result; } public void assertPair(Coin amount, ByteString info) { assertEquals(amount, this.amount); assertEquals(info, this.info); } } }