package org.bouncycastle.jsse.provider.test;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.Socket;
import java.util.concurrent.Callable;
import org.bouncycastle.util.Strings;
import junit.framework.Assert;
class TestProtocolUtil
{
/**
 * A {@link Callable} producing an {@link Exception} that additionally exposes a blocking
 * {@link #await()} hook.
 * <p>
 * {@code runClientAndServer} runs {@code call()} on a freshly started thread and then invokes
 * {@code await()} on the calling thread before it proceeds.
 */
public interface BlockingCallable
extends Callable<Exception>
{
/**
 * Blocks the calling thread until the implementation decides to let it continue.
 * <p>
 * NOTE(review): presumably this waits for a readiness/completion signal raised from
 * {@code call()} - confirm against the implementing classes.
 *
 * @throws InterruptedException presumably if the waiting thread is interrupted - the
 *             declaration only shows that implementations may throw it
 */
void await() throws InterruptedException;
}
/**
 * {@link Runnable} adapter that executes a {@code Callable<Exception>} and records its outcome
 * so that the thread which started it can inspect that outcome afterwards.
 * <p>
 * After {@link #run()} has completed, {@link #getResult()} yields:
 * <ul>
 * <li>the value returned by the callable (may be {@code null}),</li>
 * <li>the exception itself, if the callable threw an {@link Exception},</li>
 * <li>a {@link RuntimeException} whose cause is the thrown {@link Error}, if the callable
 * failed with an {@code Error}.</li>
 * </ul>
 * The result is held in a plain (non-volatile) field, so read it only after joining the thread
 * that executed this task.
 */
public static class Task
    implements Runnable
{
    private final Callable<Exception> callable;
    private Exception result = null;

    /**
     * @param callable the work to execute when {@link #run()} is invoked
     */
    public Task(Callable<Exception> callable)
    {
        this.callable = callable;
    }

    /**
     * @return the recorded outcome; {@code null} if {@link #run()} has not completed yet or the
     *         callable returned {@code null}
     */
    public Exception getResult()
    {
        return result;
    }

    /**
     * Invokes the callable once and stores its outcome; never propagates a failure of the
     * callable to the caller.
     */
    public void run()
    {
        try
        {
            result = callable.call();
        }
        catch (Exception e)
        {
            result = e;
        }
        catch (Error e)
        {
            // Errors (e.g. JUnit assertion failures raised inside the callable) would otherwise
            // only terminate the worker thread and leave 'result' null, which is
            // indistinguishable from a callable that returned null. Wrap the Error so the
            // failure is reported through getResult() with its cause preserved.
            result = new RuntimeException(e);
        }
    }
}
/**
 * Runs a server and a client callable on dedicated threads and asserts that neither of them
 * reported an exception.
 * <p>
 * Sequence: start the server thread, call {@code server.await()}, start the client thread, call
 * {@code client.await()}, join both threads (server first), then assert that both task results
 * are {@code null}.
 *
 * @param server the server side; started first
 * @param client the client side; started only after {@code server.await()} has returned
 * @throws InterruptedException if propagated from {@code await()} or {@link Thread#join()}
 */
public static void runClientAndServer(BlockingCallable server, BlockingCallable client)
    throws InterruptedException
{
    // Server side: launch it and block until it lets us continue.
    final TestProtocolUtil.Task serverOutcome = new TestProtocolUtil.Task(server);
    final Thread serverWorker = new Thread(serverOutcome);
    serverWorker.start();
    server.await();

    // Client side: same pattern, started only once the server's await() has returned.
    final TestProtocolUtil.Task clientOutcome = new TestProtocolUtil.Task(client);
    final Thread clientWorker = new Thread(clientOutcome);
    clientWorker.start();
    client.await();

    // Joining both workers before reading their results guarantees the results are final.
    serverWorker.join();
    clientWorker.join();

    Assert.assertNull(serverOutcome.getResult());
    Assert.assertNull(clientOutcome.getResult());
}
/**
 * Client half of the test exchange: sends {@code text} as one message, reads one message back
 * and asserts that the reply is exactly {@code "World"}.
 *
 * @param sock connected socket whose streams are used; it is not closed here
 * @param text the message to send
 * @throws IOException if obtaining the streams, writing or reading fails
 */
public static void doClientProtocol(Socket sock, String text)
    throws IOException
{
    final OutputStream toPeer = sock.getOutputStream();
    final InputStream fromPeer = sock.getInputStream();

    writeMessage(text, toPeer);
    Assert.assertEquals("World", readMessage(fromPeer));
}
/**
 * Server half of the test exchange: reads one message, replies with {@code text} and then
 * asserts that the message received was exactly {@code "Hello"}.
 *
 * @param sock connected socket whose streams are used; it is not closed here
 * @param text the reply to send
 * @throws IOException if obtaining the streams, reading or writing fails
 */
public static void doServerProtocol(Socket sock, String text)
    throws IOException
{
    final OutputStream toPeer = sock.getOutputStream();
    final InputStream fromPeer = sock.getInputStream();

    final String request = readMessage(fromPeer);
    writeMessage(text, toPeer);

    // The request is verified only after the reply has been written
    // (presumably so the peer is not left waiting when the check fails - confirm).
    Assert.assertEquals("Hello", request);
}
/**
 * Writes one message: the bytes of {@code text} (as produced by {@code Strings.toByteArray})
 * followed by the terminator byte {@code '!'}, then flushes the stream.
 * <p>
 * The matching reader stops at the first {@code '!'}, so {@code text} must not itself contain
 * that character.
 *
 * @param text the message payload
 * @param out  the stream to write to; it is flushed but not closed
 * @throws IOException if writing or flushing fails
 */
private static void writeMessage(String text, OutputStream out)
    throws IOException
{
    out.write(Strings.toByteArray(text));
    out.write('!');

    // The peer blocks until it has seen the terminator, so make sure the message actually
    // leaves any buffering layer instead of relying on the stream being unbuffered.
    out.flush();
}
/**
 * Reads one message: consumes bytes up to and including the terminator {@code '!'} and returns
 * everything before it, mapping each byte value directly to a {@code char}.
 *
 * @param in the stream to read from; it is not closed
 * @return the message text without the terminator (empty if the terminator comes first)
 * @throws EOFException if the stream ends before a terminator is seen
 * @throws IOException  if reading fails
 */
private static String readMessage(InputStream in)
    throws IOException
{
    StringBuilder sb = new StringBuilder();
    for (;;)
    {
        int ch = in.read();
        if (ch == '!')
        {
            return sb.toString();
        }
        if (ch < 0)
        {
            // read() signals end of stream with -1, which never equals '!'; without this check
            // the loop would keep appending (char)-1 until memory is exhausted.
            throw new EOFException(
                "Stream ended before the '!' terminator; received so far: \"" + sb + "\"");
        }
        sb.append((char)ch);
    }
}
}