/*
* Copyright 2014, The Sporting Exchange Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.betfair.cougar.transport.nio;
import com.betfair.cougar.api.ExecutionContext;
import com.betfair.cougar.api.RequestUUID;
import com.betfair.cougar.api.export.Protocol;
import com.betfair.cougar.api.geolocation.GeoLocationDetails;
import com.betfair.cougar.api.security.IdentityChain;
import com.betfair.cougar.api.security.IdentityResolver;
import com.betfair.cougar.core.api.OperationBindingDescriptor;
import com.betfair.cougar.core.api.ServiceBindingDescriptor;
import com.betfair.cougar.core.api.ServiceVersion;
import com.betfair.cougar.core.api.ev.*;
import com.betfair.cougar.core.api.exception.CougarException;
import com.betfair.cougar.core.api.exception.CougarServiceException;
import com.betfair.cougar.core.api.exception.ServerFaultCode;
import com.betfair.cougar.core.api.security.IdentityResolverFactory;
import com.betfair.cougar.core.api.tracing.Tracer;
import com.betfair.cougar.core.api.transcription.Parameter;
import com.betfair.cougar.core.api.transcription.ParameterType;
import com.betfair.cougar.core.impl.DefaultTimeConstraints;
import com.betfair.cougar.core.impl.security.CommonNameCertInfoExtractor;
import com.betfair.cougar.core.impl.tracing.CompoundTracer;
import com.betfair.cougar.core.impl.transports.TransportRegistryImpl;
import com.betfair.cougar.netutil.nio.marshalling.DefaultExecutionContextResolverFactory;
import com.betfair.cougar.transport.api.DehydratedExecutionContextResolution;
import com.betfair.cougar.transport.api.RequestTimeResolver;
import com.betfair.cougar.transport.impl.DehydratedExecutionContextResolutionImpl;
import org.slf4j.LoggerFactory;
import com.betfair.cougar.netutil.nio.marshalling.DefaultSocketTimeResolver;
import com.betfair.cougar.netutil.nio.marshalling.SocketRMIMarshaller;
import com.betfair.cougar.netutil.nio.CougarProtocol;
import com.betfair.cougar.netutil.nio.NioLogger;
import com.betfair.cougar.netutil.nio.TlsNioConfig;
import com.betfair.cougar.netutil.nio.message.*;
import com.betfair.cougar.transport.api.protocol.CougarObjectInput;
import com.betfair.cougar.transport.api.protocol.CougarObjectOutput;
import com.betfair.cougar.netutil.nio.hessian.HessianObjectIOFactory;
import com.betfair.cougar.transport.api.protocol.socket.InvocationRequest;
import com.betfair.cougar.transport.api.protocol.socket.InvocationResponse;
import com.betfair.cougar.transport.api.protocol.socket.SocketBindingDescriptor;
import com.betfair.cougar.transport.api.protocol.socket.SocketOperationBindingDescriptor;
import com.betfair.cougar.transport.socket.SocketTransportCommandProcessor;
import com.betfair.cougar.util.RequestUUIDImpl;
import com.betfair.cougar.util.UUIDGeneratorImpl;
import com.betfair.cougar.util.geolocation.GeoIPLocator;
import org.junit.After;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.runners.Parameterized.Parameters;
import org.mockito.Mockito;
import java.io.*;
import java.net.Socket;
import java.util.*;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import static com.betfair.cougar.netutil.nio.message.ProtocolMessage.ProtocolMessageType;
import static org.junit.Assert.*;
import static org.mockito.Mockito.mock;
@RunWith(value = Parameterized.class)
public class ExecutionVenueNioServerTest {
/**
 * Wraps a {@code byte[]} so that it is compared and hashed by content rather
 * than by identity, which lets version combinations be de-duplicated in a
 * hash-based {@link Set}.
 */
private static class ByteArrayWrapper {
    private final byte[] array;

    private ByteArrayWrapper(byte[] array) {
        this.array = array;
    }

    /**
     * @return the wrapped array itself (no defensive copy is made)
     */
    public byte[] getArray() {
        return array;
    }

    /**
     * Content-based equality.
     *
     * @param obj the object to compare with; may be {@code null} or of any type
     * @return {@code true} only if {@code obj} is a {@code ByteArrayWrapper}
     *         whose array has the same contents as this one
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        // A blind cast would throw for null or foreign types; equals must answer false instead.
        if (!(obj instanceof ByteArrayWrapper)) {
            return false;
        }
        return Arrays.equals(array, ((ByteArrayWrapper) obj).array);
    }

    /**
     * @return a hash of the array contents, or 0 when the wrapped array is {@code null}
     */
    @Override
    public int hashCode() {
        // Arrays.hashCode already returns 0 for a null array.
        return Arrays.hashCode(array);
    }
}
/**
 * Builds the parameter sets for the parameterized runner. Each entry is a
 * two-element array: the listen address and one combination of client
 * protocol versions, as collected by {@code addVersions}.
 *
 * @return one {@code Object[] { address, byte[] versions }} per address/combination pair
 */
@Parameters
public static Collection<Object[]> data() {
    // Only the IPv4 loopback is exercised; the IPv6 loopback ("::1") is currently left out.
    final String[] listenAddresses = { "127.0.0.1" };

    final Set<ByteArrayWrapper> combos = new HashSet<ByteArrayWrapper>();
    addVersions(combos, new byte[0], CougarProtocol.TRANSPORT_PROTOCOL_VERSION_MIN_SUPPORTED);

    // Log every combination up front so the parameterized runs can be told apart.
    for (ByteArrayWrapper combo : combos) {
        System.out.println("Version combo: " + Arrays.toString(combo.getArray()));
    }

    final List<Object[]> parameterSets = new ArrayList<Object[]>();
    for (String listenAddress : listenAddresses) {
        for (ByteArrayWrapper combo : combos) {
            parameterSets.add(new Object[] { listenAddress, combo.getArray() });
        }
    }
    return parameterSets;
}
/**
 * Recursively collects version combinations into {@code versionCombinations}.
 * For {@code nextVersion} it records both the single-version array and
 * {@code prefix} extended by {@code nextVersion}, then recurses with the
 * extended prefix for every higher version up to the maximum supported one.
 * Does nothing once {@code nextVersion} exceeds the maximum supported version.
 *
 * @param versionCombinations accumulator the combinations are added to
 * @param prefix              versions already chosen for the current combination
 * @param nextVersion         version to append next
 */
private static void addVersions(Set<ByteArrayWrapper> versionCombinations, byte[] prefix, byte nextVersion) {
    if (nextVersion <= CougarProtocol.TRANSPORT_PROTOCOL_VERSION_MAX_SUPPORTED) {
        versionCombinations.add(new ByteArrayWrapper(new byte[] { nextVersion }));

        final byte[] extended = addToEnd(prefix, nextVersion);
        versionCombinations.add(new ByteArrayWrapper(extended));

        byte candidate = (byte) (nextVersion + 1);
        while (candidate <= CougarProtocol.TRANSPORT_PROTOCOL_VERSION_MAX_SUPPORTED) {
            addVersions(versionCombinations, extended, candidate);
            candidate++;
        }
    }
}
/**
 * Returns a new array holding the contents of {@code arr} followed by
 * {@code toAdd}. The input array is not modified.
 *
 * @param arr   source array
 * @param toAdd byte to append
 * @return a fresh array one element longer than {@code arr}
 */
private static byte[] addToEnd(byte[] arr, byte toAdd) {
    final byte[] extended = Arrays.copyOf(arr, arr.length + 1);
    extended[arr.length] = toAdd;
    return extended;
}
// Message carried by the CougarServiceException that the stub execution venue
// reports when a request asks to fail.
public static final String THE_ITALIAN_JOB = "you were only supposed to blow the ruddy doors off";
// Key of the single operation bound to the server under test.
private static final OperationKey KEY = new OperationKey(new ServiceVersion("v1.0"), "UnitTestService",
"myUnitTestMethod");
// Parameters of the test operation: "pass" (Boolean) selects success or failure in the
// stub execution venue, "echoMe" (String) is the value returned on success.
private static final Parameter[] OP_PARAMS = new Parameter[] {
new Parameter("pass", new ParameterType(Boolean.class, null), true),
new Parameter("echoMe", new ParameterType(String.class, null), true) };
// Return type of the test operation; also used when unmarshalling the response.
private static final ParameterType RETURN_PARAM_TYPE = new ParameterType(String.class, null);
// Time constraints attached to every request built by createRequest().
private static final TimeConstraints TIME_CONSTRAINTS = DefaultTimeConstraints.NO_CONSTRAINTS;
/**
 * Fixed operation definition handed out by the stub execution venue for every
 * key; it simply exposes the class-level KEY, OP_PARAMS and RETURN_PARAM_TYPE
 * constants.
 */
public static final OperationDefinition OPERATION_DEFINITION = new OperationDefinition() {
@Override
public OperationKey getOperationKey() {
return KEY;
}
@Override
public Parameter[] getParameters() {
return OP_PARAMS;
}
@Override
public ParameterType getReturnType() {
return RETURN_PARAM_TYPE;
}
};
// Listen address for the server under test; supplied through the constructor.
private String address;
// NIO configuration built in startDummyEchoSocketServer() and handed to the server.
TlsNioConfig cfg;
// Server under test; started before and stopped after every test method.
private ExecutionVenueNioServer server;
// Executor given to the command processor; runs each command on a new thread.
private Executor executor;
// Tracer given to the command processor.
private Tracer tracer;
// Marshaller used both by the command processor and by the test client code.
private SocketRMIMarshaller marshaller;
// Stub execution venue that echoes a value or reports a fault.
private ExecutionVenue ev;
// Command processor wired into the server handler.
private SocketTransportCommandProcessor cmdProcessor;
// Protocol versions the test client offers in its CONNECT message; supplied through the constructor.
private byte[] clientConnectVersions;
// Factory for the object input/output used by the test client to encode requests and decode responses.
private HessianObjectIOFactory ioFactory;
public ExecutionVenueNioServerTest(Object address, Object clientConnectVersions) {
this.address = (String)address;
this.clientConnectVersions = (byte[]) clientConnectVersions;
}
/**
 * Installs a UUID generator on {@code RequestUUIDImpl} once, before any test runs.
 */
@BeforeClass
public static void setupStatic() {
    final UUIDGeneratorImpl generator = new UUIDGeneratorImpl();
    RequestUUIDImpl.setGenerator(generator);
}
/**
 * Builds and starts the server under test before each test method: configures
 * the NIO layer, wires a command processor around a stub execution venue that
 * either echoes its second argument or reports a fault, binds the single test
 * operation, then starts the server and marks it healthy.
 *
 * @throws IOException declared for the server start-up calls
 */
@Before
public void startDummyEchoSocketServer() throws IOException {
ioFactory = new HessianObjectIOFactory(false);
tracer = new CompoundTracer();
// --- NIO configuration ---
cfg = new TlsNioConfig();
final NioLogger logger = new NioLogger("ALL");
cfg.setNioLogger(logger);
cfg.setListenAddress(address);
// Port 0: the tests look up the actual port via server.getBoundPort()
// (presumably the OS assigns a free one -- confirm against the server implementation).
cfg.setListenPort(0);
cfg.setReuseAddress(true);
cfg.setTcpNoDelay(true);
cfg.setKeepAliveInterval(Integer.MAX_VALUE);
cfg.setKeepAliveTimeout(Integer.MAX_VALUE);
server = new ExecutionVenueNioServer();
server.setNioConfig(cfg);
// --- command processor and its collaborators ---
cmdProcessor = new SocketTransportCommandProcessor();
cmdProcessor.setIdentityResolverFactory(new IdentityResolverFactory());
// Runs every command on its own freshly created thread.
executor = new Executor() {
@Override
public void execute(Runnable command) {
Thread t = new Thread(command);
t.start();
}
};
DehydratedExecutionContextResolutionImpl contextResolution = new DehydratedExecutionContextResolutionImpl();
contextResolution.registerFactory(new DefaultExecutionContextResolverFactory(mock(GeoIPLocator.class), mock(RequestTimeResolver.class)));
contextResolution.init(false);
marshaller = new SocketRMIMarshaller(new CommonNameCertInfoExtractor(), contextResolution);
// NOTE(review): this factory is configured with a mock resolver but is never passed to
// anything in this method; the command processor above received a separate, unconfigured
// factory. Confirm whether that is intended.
IdentityResolverFactory identityResolverFactory = new IdentityResolverFactory();
identityResolverFactory.setIdentityResolver(mock(IdentityResolver.class));
// Stub execution venue: args[0] (Boolean) chooses between echoing args[1] and
// reporting a CougarServiceException carrying THE_ITALIAN_JOB.
ev = new ExecutionVenue() {
@Override
public void registerOperation(String ns, OperationDefinition def, Executable executable, ExecutionTimingRecorder recorder, long maxExecutionTime) {
}
@Override
public OperationDefinition getOperationDefinition(OperationKey key) {
return OPERATION_DEFINITION;
}
@Override
public Set<OperationKey> getOperationKeys() {
return null;
}
@Override
public void execute(ExecutionContext ctx, OperationKey key, Object[] args, ExecutionObserver observer, TimeConstraints clientExpiryTime) {
if ((Boolean)args[0]) {
observer.onResult(new ExecutionResult(args[1]));
} else {
observer.onResult(new ExecutionResult(new CougarServiceException(ServerFaultCode.FrameworkError, THE_ITALIAN_JOB)));
}
}
// Asynchronous variant: delegates to the synchronous execute() on the supplied executor.
@Override
public void execute(final ExecutionContext ctx, final OperationKey key, final Object[] args, final ExecutionObserver observer, final Executor executor, final TimeConstraints clientExpiryTime) {
executor.execute(new Runnable() {
@Override
public void run() {
execute(ctx, key, args, observer, clientExpiryTime);
}
});
}
@Override
public void setPreProcessors(List<ExecutionPreProcessor> preProcessorList) {
}
@Override
public void setPostProcessors(List<ExecutionPostProcessor> preProcessorList) {
}
};
cmdProcessor.setExecutor(executor);
cmdProcessor.setMarshaller(marshaller);
cmdProcessor.setExecutionVenue(ev);
cmdProcessor.setTracer(tracer);
// Binding descriptor exposing the single test operation (KEY) over the socket protocol.
ServiceBindingDescriptor desc = new SocketBindingDescriptor() {
@Override
public OperationBindingDescriptor[] getOperationBindings() {
return new OperationBindingDescriptor[] { new SocketOperationBindingDescriptor(KEY) };
}
@Override
public ServiceVersion getServiceVersion() {
return KEY.getVersion();
}
@Override
public String getServiceName() {
return KEY.getServiceName();
}
@Override
public Protocol getServiceProtocol() {
return Protocol.SOCKET;
}
};
cmdProcessor.bind(desc);
cmdProcessor.onCougarStart();
// --- server wiring and start-up ---
ExecutionVenueServerHandler handler = new ExecutionVenueServerHandler(new NioLogger("NONE"), cmdProcessor, new HessianObjectIOFactory(false));
server.setServerHandler(handler);
server.setSocketAcceptorProcessors(1);
// NOTE(review): this thread pool is not retained, so nothing in this class shuts it down.
server.setServerExecutor(Executors.newCachedThreadPool());
server.setTransportRegistry(new TransportRegistryImpl());
server.start();
server.setHealthState(true);
// NOTE(review): the session manager is attached after start(); confirm the server
// does not need it earlier.
final IoSessionManager sessionManager = new IoSessionManager();
sessionManager.setNioLogger(logger);
sessionManager.setMaxTimeToWaitForRequestCompletion(5000);
server.setSessionManager(sessionManager);
}
/**
 * Stops the server started for the current test method.
 *
 * @throws IOException declared for the server shutdown call
 */
@After
public void stopDummyEchoSocketServer() throws IOException {
    // JUnit runs @After even when @Before failed part-way; in that case no server
    // may exist yet, and an NPE here would only obscure the original set-up failure.
    if (server != null) {
        server.stop();
    }
}
/**
 * A request that asks the stub venue to succeed must come back successful and
 * carry the echoed value.
 */
@Test
public void testSocketRequest() throws Exception {
    final String echoed = "sweet";
    final InvocationResponse reply = makeSocketRequest(123, true, echoed);

    // Dump the remote failure before asserting so the cause is visible in the output.
    if (!reply.isSuccess()) {
        reply.getException().printStackTrace();
    }

    assertTrue(reply.isSuccess());
    assertEquals(echoed, reply.getResult());
}
/**
 * A request that asks the stub venue to fail must come back unsuccessful with
 * a server-fault exception whose cause carries the original fault message.
 */
@Test
public void testSocketRequestThrowingException() throws IOException {
    final InvocationResponse reply = makeSocketRequest(2, false, "");
    assertFalse(reply.isSuccess());

    final CougarException failure = reply.getException();
    assertNotNull(failure);

    final String expectedMessage = "Server fault received from remote server: FrameworkError(DSC-0002)";
    assertEquals(expectedMessage, failure.getMessage());
    assertEquals(THE_ITALIAN_JOB, failure.getCause().getMessage());
}
/**
 * Reads one framed protocol message from {@code stream}. A frame is a 4-byte
 * length, a 1-byte message type and a type-specific body; the length counts
 * everything after the length field itself.
 *
 * @param stream               stream to read from
 * @param communicationVersion negotiated protocol version (currently not consulted)
 * @return the decoded message, or {@code null} for a message type this helper does not decode
 * @throws IOException if the stream fails or ends before the whole frame has been read
 */
private Object readMessageFromInputStream(InputStream stream, byte communicationVersion) throws IOException {
    DataInputStream dis = new DataInputStream(stream);
    int messageLen = dis.readInt();
    //Read the message type
    ProtocolMessageType pm = ProtocolMessageType.getMessageByMessageType(dis.readByte());
    switch (pm) {
        case CONNECT:
            // The version count is a single byte, matching the framing used by
            // writeMessageToOutputStream (length = versions + type byte + count byte).
            byte connectVersionCount = dis.readByte();
            byte[] connectVersions = new byte[connectVersionCount];
            // readFully: a plain read() may return before the array is filled on a socket stream.
            dis.readFully(connectVersions);
            return new ConnectMessage(connectVersions);
        case REJECT:
            RejectMessageReason reason = RejectMessageReason.getByReasonCode(dis.readByte());
            byte rejectVersionCount = dis.readByte();
            byte[] acceptableVersions = new byte[rejectVersionCount];
            // Consume the advertised versions so the message carries real values and
            // no stray bytes are left in the stream for the next read.
            dis.readFully(acceptableVersions);
            return new RejectMessage(reason, acceptableVersions);
        case ACCEPT:
            byte acceptedVersion = dis.readByte();
            return new AcceptMessage(acceptedVersion);
        case MESSAGE_RESPONSE:
        case MESSAGE: // used for v1 clients
            // 9 = 1 byte message type + 8 bytes correlation id
            byte[] payload = new byte[messageLen - 9];
            long correlationId = dis.readLong();
            dis.readFully(payload);
            return new ResponseMessage(correlationId, payload);
    }
    return null;
}
/**
 * Encodes {@code message} into a single frame and writes it to {@code stream}
 * in one call. Request messages are framed as length, type, correlation id and
 * payload; connect messages as length, type, version count and versions. Any
 * other message type produces an empty write.
 *
 * @param message              a {@code RequestMessage} or {@code ConnectMessage}
 * @param stream               destination stream; flushed after the write
 * @param communicationVersion protocol version, which selects the request message type
 * @throws IOException if writing to {@code stream} fails
 */
private void writeMessageToOutputStream(Object message, OutputStream stream, byte communicationVersion) throws IOException {
    final ByteArrayOutputStream frame = new ByteArrayOutputStream();
    final DataOutputStream out = new DataOutputStream(frame);

    if (message instanceof RequestMessage) {
        final RequestMessage request = (RequestMessage) message;
        // Length covers the type byte, the 8-byte correlation id and the payload.
        out.writeInt(request.getPayload().length + 9);
        final ProtocolMessageType requestType =
                communicationVersion == CougarProtocol.TRANSPORT_PROTOCOL_VERSION_CLIENT_ONLY_RPC
                        ? ProtocolMessageType.MESSAGE
                        : ProtocolMessageType.MESSAGE_REQUEST;
        out.writeByte(requestType.getMessageType());
        out.writeLong(request.getCorrelationId());
        out.write(request.getPayload());
    }
    else if (message instanceof ConnectMessage) {
        final ConnectMessage connect = (ConnectMessage) message;
        // Length covers the type byte, the count byte and the versions.
        out.writeInt(connect.getApplicationVersions().length + 2);
        out.write(connect.getProtocolMessageType().getMessageType());
        out.write((byte) connect.getApplicationVersions().length);
        out.write(connect.getApplicationVersions());
    }

    out.flush();
    stream.write(frame.toByteArray());
    stream.flush();
}
/**
 * A client that offers only a protocol version above the maximum supported
 * one must be answered with a REJECT message.
 */
@Test
public void testBadHandshake() throws IOException {
    Socket connectedClient = new Socket(server.getBoundAddress(), server.getBoundPort());
    try {
        OutputStream output = connectedClient.getOutputStream();
        InputStream input = connectedClient.getInputStream();
        //We, a nonsense client, only support a version one above the newest the server knows
        byte communicationVersion = CougarProtocol.TRANSPORT_PROTOCOL_VERSION_MAX_SUPPORTED + 1;
        writeMessageToOutputStream(new ConnectMessage(new byte[] { communicationVersion }), output, communicationVersion);
        ProtocolMessage message = (ProtocolMessage) readMessageFromInputStream(input, communicationVersion);
        assertEquals("Incorrect application protocol version was accepted", ProtocolMessageType.REJECT, message.getProtocolMessageType());
    } finally {
        // Release the client socket whether or not the assertion passes.
        connectedClient.close();
    }
}
/**
 * Opens a client connection to the server under test, performs the handshake
 * offering {@code clientConnectVersions}, sends one invocation request and
 * returns the decoded response. The connection is closed before returning.
 *
 * @param correlationId correlation id to send; the response must echo it
 * @param pass          first operation argument: whether the stub venue should succeed
 * @param echoMe        second operation argument: the value echoed on success
 * @return the unmarshalled invocation response
 * @throws IOException if the connection or any read/write on it fails
 */
private InvocationResponse makeSocketRequest(long correlationId, boolean pass, String echoMe) throws IOException {
    Socket connectedClient = new Socket(server.getBoundAddress(), server.getBoundPort());
    try {
        OutputStream output = connectedClient.getOutputStream();
        InputStream input = connectedClient.getInputStream();
        byte communicationVersion = CougarProtocol.TRANSPORT_PROTOCOL_VERSION_MIN_SUPPORTED; // handshake is set in stone
        //start with handshake
        writeMessageToOutputStream(new ConnectMessage(clientConnectVersions), output, communicationVersion);
        ProtocolMessage message = (ProtocolMessage) readMessageFromInputStream(input, communicationVersion);
        assertEquals("Handshake was incorrect", ProtocolMessageType.ACCEPT, message.getProtocolMessageType());
        // From here on, speak whichever version the server accepted.
        communicationVersion = ((AcceptMessage) message).getAcceptedVersion();
        //Now on to the message providing we handshook correctly
        //Construct the byte stream to be sent, starts with correlation id, then the marshalled request
        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        CougarObjectOutput out = ioFactory.newCougarObjectOutput(baos, communicationVersion);
        Object[] args = { pass, echoMe };
        marshaller.writeInvocationRequest(createRequest(args), out, null, null, communicationVersion);
        out.flush();
        final byte[] bytes = baos.toByteArray();
        writeMessageToOutputStream(new RequestMessage(correlationId, bytes), output, communicationVersion);
        ResponseMessage response = (ResponseMessage) readMessageFromInputStream(input, communicationVersion);
        ByteArrayInputStream bais = new ByteArrayInputStream(response.getPayload());
        CougarObjectInput dis = ioFactory.newCougarObjectInput(bais, communicationVersion);
        assertEquals(correlationId, response.getCorrelationId());
        return marshaller.readInvocationResponse(RETURN_PARAM_TYPE, dis);
    } finally {
        // The response has been fully read into memory by now; release the socket
        // so client connections do not pile up across the parameterized runs.
        connectedClient.close();
    }
}
/**
 * Builds an invocation request for the test operation (KEY / OP_PARAMS /
 * TIME_CONSTRAINTS) carrying the given arguments and a stub execution context.
 *
 * @param args operation arguments, returned as-is by the request
 * @return a request whose execution context is a fresh stub on every call
 */
public InvocationRequest createRequest(final Object[] args) {
    return new InvocationRequest() {
        @Override
        public OperationKey getOperationKey() {
            return KEY;
        }

        @Override
        public Parameter[] getParameters() {
            return OP_PARAMS;
        }

        @Override
        public TimeConstraints getTimeConstraints() {
            return TIME_CONSTRAINTS;
        }

        @Override
        public Object[] getArgs() {
            return args;
        }

        @Override
        public ExecutionContext getExecutionContext() {
            return newStubExecutionContext();
        }
    };
}

/**
 * Creates an execution context stub: no identity, no request UUID, no
 * timestamps, tracing off, insecure transport, and a stub location.
 */
private static ExecutionContext newStubExecutionContext() {
    return new ExecutionContext() {
        @Override
        public GeoLocationDetails getLocation() {
            return newStubLocation();
        }

        @Override
        public IdentityChain getIdentity() {
            return null;
        }

        @Override
        public RequestUUID getRequestUUID() {
            return null;
        }

        @Override
        public Date getReceivedTime() {
            return null;
        }

        @Override
        public Date getRequestTime() {
            return null;
        }

        @Override
        public boolean traceLoggingEnabled() {
            return false;
        }

        @Override
        public int getTransportSecurityStrengthFactor() {
            return 0;
        }

        @Override
        public boolean isTransportSecure() {
            return false;
        }
    };
}

/**
 * Creates a location stub that reports a single resolved address and leaves
 * every other detail unset.
 */
private static GeoLocationDetails newStubLocation() {
    return new GeoLocationDetails() {
        @Override
        public List<String> getResolvedAddresses() {
            return Collections.singletonList("5.1.8.6");
        }

        @Override
        public String getRemoteAddr() {
            return null;
        }

        @Override
        public String getInferredCountry() {
            return null;
        }

        @Override
        public String getCountry() {
            return null;
        }

        @Override
        public String getLocation() {
            return null;
        }

        @Override
        public boolean isLowConfidenceGeoLocation() {
            return false;
        }
    };
}
}