/* * Copyright 2014, The Sporting Exchange Limited * Copyright 2015, Simon Matić Langford * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.betfair.cougar.netutil.nio.marshalling; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; import static org.mockito.Matchers.anySet; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.lang.reflect.Array; import java.security.Principal; import java.security.cert.X509Certificate; import java.util.*; import java.util.Map.Entry; import com.betfair.cougar.api.DehydratedExecutionContext; import com.betfair.cougar.api.export.Protocol; import com.betfair.cougar.api.security.IdentityResolver; import com.betfair.cougar.core.api.builder.DehydratedExecutionContextBuilder; import com.betfair.cougar.core.api.builder.ExecutionContextBuilder; import com.betfair.cougar.core.api.ev.TimeConstraints; import com.betfair.cougar.core.api.exception.CougarException; import com.betfair.cougar.core.api.transcription.TranscribableParams; import com.betfair.cougar.core.impl.DefaultTimeConstraints; import com.betfair.cougar.core.impl.security.CommonNameCertInfoExtractor; import com.betfair.cougar.marshalling.impl.RandomException; import com.betfair.cougar.marshalling.impl.SimpleApplicationException; import com.betfair.cougar.marshalling.impl.SimpleExecutionContext; import com.betfair.cougar.marshalling.impl.SimpleGeoLocationDetails; import com.betfair.cougar.netutil.nio.CougarProtocol; import com.betfair.cougar.transport.api.*; import com.betfair.cougar.transport.impl.DehydratedExecutionContextResolutionImpl; import com.betfair.cougar.util.RequestUUIDImpl; import com.betfair.cougar.util.UUIDGeneratorImpl; import com.betfair.cougar.util.geolocation.RemoteAddressUtils; import com.google.common.collect.ImmutableMultiset; import org.apache.commons.lang.builder.EqualsBuilder; import org.apache.commons.lang.builder.HashCodeBuilder; import org.junit.Assert; import org.junit.Before; import org.junit.BeforeClass; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.mockito.ArgumentCaptor; import org.mockito.Mockito; import com.betfair.cougar.api.ExecutionContext; import com.betfair.cougar.api.ResponseCode; import com.betfair.cougar.api.fault.CougarApplicationException; import com.betfair.cougar.api.geolocation.GeoLocationDetails; import com.betfair.cougar.api.security.Credential; import com.betfair.cougar.api.security.Identity; import com.betfair.cougar.api.security.IdentityChain; import com.betfair.cougar.api.security.IdentityToken; import com.betfair.cougar.api.security.InvalidCredentialsException; import com.betfair.cougar.core.api.ServiceVersion; import com.betfair.cougar.core.api.ev.OperationKey; import com.betfair.cougar.core.api.exception.CougarFrameworkException; import com.betfair.cougar.core.api.transcription.Parameter; import com.betfair.cougar.core.api.transcription.ParameterType; import com.betfair.cougar.core.api.transcription.ParameterType.Type; import com.betfair.cougar.core.api.transcription.Transcribable; import com.betfair.cougar.core.api.transcription.TranscriptionInput; import com.betfair.cougar.core.api.transcription.TranscriptionOutput; import com.betfair.cougar.marshalling.impl.to.Cycle1; import com.betfair.cougar.marshalling.impl.to.Cycle2; import com.betfair.cougar.marshalling.impl.to.Foo; import com.betfair.cougar.marshalling.impl.to.FooDelegateImpl; import com.betfair.cougar.transport.api.protocol.CougarObjectInput; import com.betfair.cougar.transport.api.protocol.CougarObjectOutput; import com.betfair.cougar.netutil.nio.hessian.HessianObjectIOFactory; import com.betfair.cougar.transport.api.protocol.socket.InvocationRequest; import com.betfair.cougar.transport.api.protocol.socket.InvocationResponse; import com.betfair.cougar.util.geolocation.GeoIPLocator; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; /** * Unit test for the @see SocketRMIMarshallerTest class * */ @RunWith(Parameterized.class) public class SocketRMIMarshallerTest { private byte protocolVersion; public SocketRMIMarshallerTest(byte protocolVersion) { this.protocolVersion = protocolVersion; } @Parameterized.Parameters public static Collection<Object[]> params() { List<Object[]> ret = new ArrayList<Object[]>(); for (byte b=CougarProtocol.TRANSPORT_PROTOCOL_VERSION_MIN_SUPPORTED; b<=CougarProtocol.TRANSPORT_PROTOCOL_VERSION_MAX_SUPPORTED; b++) { ret.add(new Object[] {b}); } return ret; } @BeforeClass public static void setup() { RequestUUIDImpl.setGenerator(new UUIDGeneratorImpl()); } private class PrincipalImpl implements Principal { private String name; public PrincipalImpl(String name) { this.name = name; } @Override public String getName() { return name; } @Override public boolean equals(Object that) { return EqualsBuilder.reflectionEquals(this, that); } @Override public int hashCode() { return HashCodeBuilder.reflectionHashCode(this); } }; private class CredentialImpl implements Credential { private String name, value; public CredentialImpl(String name, String value) { this.name = name; this.value = value; } @Override public String getName() { return name; } @Override public Object getValue() { return value; } @Override public boolean equals(Object that) { return EqualsBuilder.reflectionEquals(this, that); } @Override public int hashCode() { return HashCodeBuilder.reflectionHashCode(this); } } private class IdentityImpl implements Identity { private Principal principal; private Credential credential; public IdentityImpl(Principal principal, Credential credential) { this.principal = principal; this.credential = credential; } @Override public Principal getPrincipal() { return principal; } @Override public Credential getCredential() { return credential; } @Override public boolean equals(Object that) { return EqualsBuilder.reflectionEquals(this, that); } @Override public int hashCode() { return HashCodeBuilder.reflectionHashCode(this); } } private class IdentityChainImpl implements IdentityChain { private List<Identity> identities; public IdentityChainImpl(List<Identity> identities) { this.identities = identities; } @Override public void addIdentity(Identity identity) { identities.add(identity); } @Override public List<Identity> getIdentities() { return identities; } @Override public <T extends Identity> List<T> getIdentities(Class<T> clazz) { return null; } @Override public boolean equals(Object that) { return EqualsBuilder.reflectionEquals(this, that); } @Override public int hashCode() { return HashCodeBuilder.reflectionHashCode(this); } } private class IdentityResolverImpl implements IdentityResolver { @Override // Convert identity chain into tokens public List<IdentityToken> tokenise(IdentityChain chain) { List<IdentityToken> tokens = new ArrayList<IdentityToken>(); if (chain != null) { for (Identity identity : chain.getIdentities()) { IdentityToken token = new IdentityToken( identity.getPrincipal().getName(), new StringBuilder( identity.getCredential().getName()) .append(",") .append(identity.getCredential().getValue()) .toString() ); tokens.add(token); } } return tokens; } @Override public void resolve(IdentityChain chain, DehydratedExecutionContext ctx) throws InvalidCredentialsException { for (final IdentityToken token : ctx.getIdentityTokens()) { Principal principal = new PrincipalImpl(token.getName()); Scanner scanner = new Scanner(token.getValue()).useDelimiter(","); Credential credential = new CredentialImpl(scanner.next(), scanner.next()); Identity identity = new IdentityImpl(principal, credential); chain.addIdentity(identity); } } } private static HessianObjectIOFactory ioFactory; private static DehydratedExecutionContextResolutionImpl contextResolution = new DehydratedExecutionContextResolutionImpl(); private SocketRMIMarshaller cut = new SocketRMIMarshaller(new CommonNameCertInfoExtractor(), contextResolution); private static ArgumentCaptor<SocketContextResolutionParams> socketContextResolutionParamsArgumentCaptor = ArgumentCaptor.forClass(SocketContextResolutionParams.class); @BeforeClass public static void staticBefore() { DehydratedExecutionContextResolver<SocketContextResolutionParams, Void> additionalParamsMock = mock(DehydratedExecutionContextResolver.class); when(additionalParamsMock.supportedComponents()).thenReturn(new DehydratedExecutionContextComponent[] { DehydratedExecutionContextComponent.ReceivedTime }); final ArgumentCaptor<DehydratedExecutionContextBuilder> builderArgumentCaptor = ArgumentCaptor.forClass(DehydratedExecutionContextBuilder.class); final ArgumentCaptor<Void> voidArgumentCaptor = ArgumentCaptor.forClass(Void.class); doAnswer(new Answer<Void>() { @Override public Void answer(InvocationOnMock invocation) throws Throwable { List<DehydratedExecutionContextBuilder> allBuilders = builderArgumentCaptor.getAllValues(); allBuilders.get(allBuilders.size()-1).setReceivedTime(new Date()); return null; } }).when(additionalParamsMock).resolve(socketContextResolutionParamsArgumentCaptor.capture(),voidArgumentCaptor.capture(),builderArgumentCaptor.capture()); doAnswer(new Answer<Void>() { @Override public Void answer(InvocationOnMock invocation) throws Throwable { Set<DehydratedExecutionContextComponent> components = (Set<DehydratedExecutionContextComponent>) invocation.getArguments()[0]; if (!components.contains(DehydratedExecutionContextComponent.ReceivedTime)) { throw new RuntimeException("I'm not handling what i want to!"); } return null; } }).when(additionalParamsMock).resolving(anySet()); DehydratedExecutionContextResolverFactory additionalParamsFactory = mock(DehydratedExecutionContextResolverFactory.class); when(additionalParamsFactory.resolvers(Protocol.SOCKET)).thenReturn(new DehydratedExecutionContextResolver[] { additionalParamsMock }); contextResolution.registerFactory(new DefaultExecutionContextResolverFactory(mock(GeoIPLocator.class), mock(RequestTimeResolver.class))); contextResolution.registerFactory(additionalParamsFactory); contextResolution.init(false); } @Before public void before() { ioFactory = new HessianObjectIOFactory(false); } @Test public void testGeoLocationMarshalling() throws IOException { ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); CougarObjectOutput cougarObjectOutput = ioFactory.newCougarObjectOutput(outputStream, protocolVersion); //Only the resolved IP is relevant here SimpleGeoLocationDetails toMarshall = new SimpleGeoLocationDetails("1.2.3.4"); cut.writeGeoLocation(toMarshall, cougarObjectOutput, protocolVersion); cougarObjectOutput.flush(); byte[] theBytes = outputStream.toByteArray(); assertNotNull(theBytes); GeoLocationParameters unMarshalled = cut.readGeoLocation(ioFactory.newCougarObjectInput(new ByteArrayInputStream(theBytes), protocolVersion), "10.20.30.40", protocolVersion); assertEquals("10.20.30.40", unMarshalled.getRemoteAddress()); assertEquals(RemoteAddressUtils.parse("1.2.3.4", "1.2.3.4," + RemoteAddressUtils.localAddressList), unMarshalled.getAddressList()); assertNull(unMarshalled.getInferredCountry()); } @Test public void testGeoLocationMarshallingWithInferredCountry() throws IOException { String inferredCountry = null; if (protocolVersion >= CougarProtocol.TRANSPORT_PROTOCOL_VERSION_START_TLS) { inferredCountry = "JM"; } ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); CougarObjectOutput cougarObjectOutput = ioFactory.newCougarObjectOutput(outputStream, protocolVersion); //Only the resolved IP is relevant here SimpleGeoLocationDetails toMarshall = new SimpleGeoLocationDetails(Collections.singletonList("1.2.3.4"), inferredCountry); cut.writeGeoLocation(toMarshall, cougarObjectOutput, protocolVersion); cougarObjectOutput.flush(); byte[] theBytes = outputStream.toByteArray(); assertNotNull(theBytes); GeoLocationParameters unMarshalled = cut.readGeoLocation(ioFactory.newCougarObjectInput(new ByteArrayInputStream(theBytes), protocolVersion), "10.20.30.40", protocolVersion); assertEquals("10.20.30.40", unMarshalled.getRemoteAddress()); assertEquals(RemoteAddressUtils.parse("1.2.3.4", "1.2.3.4," + RemoteAddressUtils.localAddressList), unMarshalled.getAddressList()); if (protocolVersion >= CougarProtocol.TRANSPORT_PROTOCOL_VERSION_START_TLS) { assertEquals("JM",unMarshalled.getInferredCountry()); } else { assertNull(unMarshalled.getInferredCountry()); } } @Test public void testGeoLocationMarshallingWithNullResolvedIP() throws IOException { ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); CougarObjectOutput cougarObjectOutput = ioFactory.newCougarObjectOutput(outputStream, protocolVersion); //Only the resolved IP is relevant here SimpleGeoLocationDetails toMarshall = new SimpleGeoLocationDetails((List)null); cut.writeGeoLocation(toMarshall, cougarObjectOutput, protocolVersion); cougarObjectOutput.flush(); byte[] theBytes = outputStream.toByteArray(); assertNotNull(theBytes); GeoLocationParameters unMarshalled = cut.readGeoLocation(ioFactory.newCougarObjectInput(new ByteArrayInputStream(theBytes), protocolVersion), "10.20.30.40", protocolVersion); assertEquals("10.20.30.40", unMarshalled.getRemoteAddress()); assertEquals(RemoteAddressUtils.parse(RemoteAddressUtils.localAddressList, RemoteAddressUtils.localAddressList), unMarshalled.getAddressList()); assertNull(unMarshalled.getInferredCountry()); } @Test public void testGeoLocationMarshallingWithEmptyResolvedIP() throws IOException { ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); CougarObjectOutput cougarObjectOutput = ioFactory.newCougarObjectOutput(outputStream, protocolVersion); //Only the resolved IP is relevant here List empty = Collections.emptyList(); SimpleGeoLocationDetails toMarshall = new SimpleGeoLocationDetails(empty); cut.writeGeoLocation(toMarshall, cougarObjectOutput, protocolVersion); cougarObjectOutput.flush(); byte[] theBytes = outputStream.toByteArray(); assertNotNull(theBytes); GeoLocationParameters unMarshalled = cut.readGeoLocation(ioFactory.newCougarObjectInput(new ByteArrayInputStream(theBytes), protocolVersion), "10.20.30.40", protocolVersion); assertEquals("10.20.30.40", unMarshalled.getRemoteAddress()); assertEquals(RemoteAddressUtils.parse(RemoteAddressUtils.localAddressList, RemoteAddressUtils.localAddressList), unMarshalled.getAddressList()); assertNull(unMarshalled.getInferredCountry()); } @Test public void testGeoLocationMarshallingWithMultipleResolvedIP() throws IOException { ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); CougarObjectOutput cougarObjectOutput = ioFactory.newCougarObjectOutput(outputStream, protocolVersion); //Only the resolved IP is relevant here List empty = Collections.emptyList(); SimpleGeoLocationDetails toMarshall = new SimpleGeoLocationDetails(Arrays.asList("127.0.0.1","128.0.0.1")); cut.writeGeoLocation(toMarshall, cougarObjectOutput, protocolVersion); cougarObjectOutput.flush(); byte[] theBytes = outputStream.toByteArray(); assertNotNull(theBytes); final GeoLocationDetails gld = mock(GeoLocationDetails.class); GeoLocationParameters unMarshalled = cut.readGeoLocation(ioFactory.newCougarObjectInput(new ByteArrayInputStream(theBytes), protocolVersion), "10.20.30.40", protocolVersion); assertEquals("10.20.30.40", unMarshalled.getRemoteAddress()); assertEquals(RemoteAddressUtils.parse("127.0.0.1", "127.0.0.1,128.0.0.1," + RemoteAddressUtils.localAddressList), unMarshalled.getAddressList()); assertNull(unMarshalled.getInferredCountry()); } @Test(expected=IOException.class) public void testGeoUnmarshallingWithEmptyInputStream() throws IOException { //When we attempt to unmarshall an empty stream, an EOFException is thrown cut.readGeoLocation(ioFactory.newCougarObjectInput(new ByteArrayInputStream(new byte[] {}), protocolVersion), "10.20.30.40", protocolVersion); } @Test public void testRequestMarshalling() throws IOException { final ExecutionContext ctx = new SimpleExecutionContext(); final OperationKey key = new OperationKey(new ServiceVersion("v1.0"), "UnitTestService", "myUnitTestMethod"); final Parameter[] params = new Parameter[] { new Parameter("param1", new ParameterType(String.class, null), true) }; final TimeConstraints timeConstraints = DefaultTimeConstraints.NO_CONSTRAINTS; final Object[] args = new Object[] { "hello" }; ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); CougarObjectOutput cougarObjectOutput = ioFactory.newCougarObjectOutput(outputStream, protocolVersion); InvocationRequest request = new InvocationRequest() { @Override public Object[] getArgs() { return args; } @Override public ExecutionContext getExecutionContext() { return ctx; } @Override public OperationKey getOperationKey() { return key; } @Override public Parameter[] getParameters() { return params; } @Override public TimeConstraints getTimeConstraints() { return timeConstraints; } }; Map<String,String> additionalParams = new HashMap<>(); additionalParams.put("paramA","valueA"); cut.writeInvocationRequest(request, cougarObjectOutput, identityResolver, additionalParams, protocolVersion); cougarObjectOutput.flush(); cougarObjectOutput.close(); //String resolvedAddresses = RemoteAddressUtils.externaliseWithLocalAddresses(ctx.getLocation().getResolvedAddresses()); //when(geoIpLocator.getGeoLocation("10.20.30.40", RemoteAddressUtils.parse("10.20.30.40", resolvedAddresses), null)).thenReturn(ctx.getLocation()); CougarObjectInput in = ioFactory.newCougarObjectInput(new ByteArrayInputStream(outputStream.toByteArray()), protocolVersion); DehydratedExecutionContext actualContext = cut.readExecutionContext(in, "10.20.30.40", new X509Certificate[0], 0, protocolVersion); OperationKey actualKey = cut.readOperationKey(in); Object[] actualArgs = cut.readArgs(params, in); assertNotNull(actualContext); assertEquals(key, actualKey); assertArrayEquals(args, actualArgs); List<SocketContextResolutionParams> allSocketParams = socketContextResolutionParamsArgumentCaptor.getAllValues(); Map<String,String> reslvedAdditionalParams = allSocketParams.get(allSocketParams.size()-1).getAdditionalData(); if (protocolVersion >= CougarProtocol.TRANSPORT_PROTOCOL_VERSION_COMPOUND_REQUEST_UUID) { assertEquals(1,reslvedAdditionalParams.size()); assertEquals("valueA",reslvedAdditionalParams.get("paramA")); } else { assertEquals(0,reslvedAdditionalParams.size()); } } @Test public void testResponseMarshallingWithReturnedValue() throws IOException { ParameterType resultType = new ParameterType(String.class, null); InvocationResponse response = new SocketRMIMarshaller.InvocationResponseImpl(new String("result!"), null); ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); CougarObjectOutput cougarObjectOutput = ioFactory.newCougarObjectOutput(outputStream, protocolVersion); cut.writeInvocationResponse(response, cougarObjectOutput, protocolVersion); cougarObjectOutput.flush(); InvocationResponse actualResponse = cut.readInvocationResponse(resultType, ioFactory.newCougarObjectInput(new ByteArrayInputStream(outputStream.toByteArray()), protocolVersion)); assertTrue(actualResponse.isSuccess()); assertEquals(response.getResult(), actualResponse.getResult()); assertNull(actualResponse.getException()); } @Test /** * test the serialisation and deserialisation of response object where the response is created using a delegate */ public void testResponseMarshallingWithReturnedDelegate() throws Exception { ParameterType resultType = new ParameterType(Foo.class,null); Foo foo = new Foo(new FooDelegateImpl("foo")); InvocationResponse response = new SocketRMIMarshaller.InvocationResponseImpl(foo,null); ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); CougarObjectOutput cougarObjectOutput = ioFactory.newCougarObjectOutput(outputStream, protocolVersion); cut.writeInvocationResponse(response, cougarObjectOutput, protocolVersion); cougarObjectOutput.flush(); InvocationResponse actualResponse = cut.readInvocationResponse(resultType, ioFactory.newCougarObjectInput(new ByteArrayInputStream(outputStream.toByteArray()), protocolVersion)); assertTrue(actualResponse.isSuccess()); Object responseObject = removeDelegates(response.getResult()); assertEquals(responseObject, actualResponse.getResult()); assertNull(actualResponse.getException()); } @Test /** * test the serialisation and deserialisation of response object where the response graph has cycles */ public void testResponseMarshallingWithCycles() throws Exception { ParameterType resultType = new ParameterType(Cycle1.class,null); Cycle1 cycle1 = new Cycle1(); Cycle2 cycle2 = new Cycle2(); cycle1.setCycle2(cycle2); cycle2.setCycle1(cycle1); InvocationResponse response = new SocketRMIMarshaller.InvocationResponseImpl(cycle1,null); ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); CougarObjectOutput cougarObjectOutput = ioFactory.newCougarObjectOutput(outputStream, protocolVersion); cut.writeInvocationResponse(response, cougarObjectOutput, protocolVersion); cougarObjectOutput.flush(); InvocationResponse actualResponse = cut.readInvocationResponse(resultType, ioFactory.newCougarObjectInput(new ByteArrayInputStream(outputStream.toByteArray()), protocolVersion)); assertTrue(actualResponse.isSuccess()); assertTrue(actualResponse.getResult() == ((Cycle1)actualResponse.getResult()).getCycle2().getCycle1()); assertNull(actualResponse.getException()); } @Test /** * test the serialisation and deserialisation of response object where the response is created using a delegate * and include test handling for null */ public void testResponseMarshallingWithReturnedDelegateWithNull() throws Exception { ParameterType resultType = new ParameterType(Foo.class,null); Foo foo = new Foo(new FooDelegateImpl("foo")); foo.setBar(null); InvocationResponse response = new SocketRMIMarshaller.InvocationResponseImpl(foo,null); ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); CougarObjectOutput cougarObjectOutput = ioFactory.newCougarObjectOutput(outputStream, protocolVersion); cut.writeInvocationResponse(response, cougarObjectOutput, protocolVersion); cougarObjectOutput.flush(); InvocationResponse actualResponse = cut.readInvocationResponse(resultType, ioFactory.newCougarObjectInput(new ByteArrayInputStream(outputStream.toByteArray()), protocolVersion)); assertTrue(actualResponse.isSuccess()); Object responseObject = removeDelegates(response.getResult()); assertEquals(responseObject, actualResponse.getResult()); assertNull(actualResponse.getException()); } @Test public void testResponseMarshallingWithVoidReturn() throws IOException { InvocationResponse response = new SocketRMIMarshaller.InvocationResponseImpl(null, null); ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); CougarObjectOutput cougarObjectOutput = ioFactory.newCougarObjectOutput(outputStream, protocolVersion); cut.writeInvocationResponse(response, cougarObjectOutput, protocolVersion); cougarObjectOutput.flush(); cougarObjectOutput.close(); InvocationResponse actualResponse = cut.readInvocationResponse(null, ioFactory.newCougarObjectInput(new ByteArrayInputStream(outputStream.toByteArray()), protocolVersion)); assertTrue(actualResponse.isSuccess()); assertNull(actualResponse.getResult()); assertNull(actualResponse.getException()); } @Test public void testResponseMarshallingWithException1() throws IOException { ParameterType resultType = new ParameterType(String.class, null); InvocationResponse response = new SocketRMIMarshaller.InvocationResponseImpl(null, new CougarFrameworkException("All went bad")); ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); CougarObjectOutput cougarObjectOutput = ioFactory.newCougarObjectOutput(outputStream, protocolVersion); cut.writeInvocationResponse(response, cougarObjectOutput, protocolVersion); cougarObjectOutput.flush(); InvocationResponse actualResponse = cut.readInvocationResponse(resultType, ioFactory.newCougarObjectInput(new ByteArrayInputStream(outputStream.toByteArray()), protocolVersion)); assertFalse(actualResponse.isSuccess()); assertNotNull(actualResponse.getException()); assertEquals(response.getException().getFault(), ((CougarException) actualResponse.getException().getCause()).getFault()); } @Test public void testResponseMarshallingWithException2() throws IOException { final String SPURIOUS_EXCEPTION = "Spurious exception"; ParameterType resultType = new ParameterType(String.class, null); InvocationResponse response = new SocketRMIMarshaller.InvocationResponseImpl(null, new CougarFrameworkException("All went bad", new RandomException(SPURIOUS_EXCEPTION))); ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); CougarObjectOutput cougarObjectOutput = ioFactory.newCougarObjectOutput(outputStream, protocolVersion); cut.writeInvocationResponse(response, cougarObjectOutput, protocolVersion); cougarObjectOutput.flush(); InvocationResponse actualResponse = cut.readInvocationResponse(resultType, ioFactory.newCougarObjectInput(new ByteArrayInputStream(outputStream.toByteArray()), protocolVersion)); assertFalse(actualResponse.isSuccess()); RandomException cause = (RandomException)actualResponse.getException().getCause(); Assert.assertEquals(SPURIOUS_EXCEPTION, cause.getMessage()); assertEquals(response.getException().getFault(), actualResponse.getException().getFault()); } @Test public void testResponseMarshallingWithException3() throws IOException { SimpleApplicationException ex = new SimpleApplicationException(ResponseCode.InternalError, "bang"); ParameterType resultType = new ParameterType(String.class, null); InvocationResponse response = new SocketRMIMarshaller.InvocationResponseImpl(null, new CougarFrameworkException("All went bad", ex)); ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); CougarObjectOutput cougarObjectOutput = ioFactory.newCougarObjectOutput(outputStream, protocolVersion); cut.writeInvocationResponse(response, cougarObjectOutput, protocolVersion); cougarObjectOutput.flush(); InvocationResponse actualResponse = cut.readInvocationResponse(resultType, ioFactory.newCougarObjectInput(new ByteArrayInputStream(outputStream.toByteArray()), protocolVersion)); assertFalse(actualResponse.isSuccess()); CougarApplicationException actualCause = (CougarApplicationException)actualResponse.getException().getCause(); Assert.assertEquals(ex.getResponseCode(), actualCause.getResponseCode()); List<String> collatedExpectedFaultList = new ArrayList<String>(); for (String[] group : ex.getApplicationFaultMessages()) { collatedExpectedFaultList.addAll(Arrays.asList(group)); } List<String> collatedActualFaultList = new ArrayList<String>(); for (String[] group : actualCause.getApplicationFaultMessages()) { collatedActualFaultList.addAll(Arrays.asList(group)); } assertTrue(ImmutableMultiset.copyOf(collatedExpectedFaultList).equals(ImmutableMultiset.copyOf(collatedActualFaultList))); Assert.assertEquals(ex.getApplicationFaultNamespace(), actualCause.getApplicationFaultNamespace()); } private IdentityResolver identityResolver = new IdentityResolverImpl(); @Test public void testIdentityChainMarshallsNoIdentities() throws IOException { IdentityChain expected = new IdentityChainImpl(new ArrayList<Identity>()); ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); CougarObjectOutput cougarObjectOutput = ioFactory.newCougarObjectOutput(outputStream, protocolVersion); cougarObjectOutput.writeString("127.0.0.1"); // address if (protocolVersion >= CougarProtocol.TRANSPORT_PROTOCOL_VERSION_START_TLS) { cougarObjectOutput.writeString(null); } cut.writeIdentity(expected, cougarObjectOutput, identityResolver); cut.writeRequestUUID(new RequestUUIDImpl(), cougarObjectOutput, protocolVersion); cut.writeReceivedTime(new Date(), cougarObjectOutput); cougarObjectOutput.writeBoolean(false); // traceEnabled cut.writeRequestTime(cougarObjectOutput, protocolVersion); cut.writeAdditionalParams(null, cougarObjectOutput, protocolVersion); cougarObjectOutput.flush(); DehydratedExecutionContext ctx = cut.readExecutionContext(ioFactory.newCougarObjectInput(new ByteArrayInputStream(outputStream.toByteArray()), protocolVersion), "127.0.0.1", new X509Certificate[0], 0, protocolVersion); assertEquals(0, ctx.getIdentityTokens().size()); assertNull(ctx.getIdentity()); } @Test public void testIdentityChainMarshallsOneIdentity() throws IOException { final Identity joe = createIdentity("joeBloggs", "password", "fido123"); IdentityChain expected = new IdentityChainImpl(new ArrayList<Identity>() {{ add(joe); }}); ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); CougarObjectOutput cougarObjectOutput = ioFactory.newCougarObjectOutput(outputStream, protocolVersion); cougarObjectOutput.writeString("127.0.0.1"); // address if (protocolVersion >= CougarProtocol.TRANSPORT_PROTOCOL_VERSION_START_TLS) { cougarObjectOutput.writeString(null); } cut.writeIdentity(expected, cougarObjectOutput, identityResolver); cut.writeRequestUUID(new RequestUUIDImpl(), cougarObjectOutput, protocolVersion); cut.writeReceivedTime(new Date(), cougarObjectOutput); cougarObjectOutput.writeBoolean(false); // traceEnabled cut.writeRequestTime(cougarObjectOutput, protocolVersion); cut.writeAdditionalParams(null, cougarObjectOutput, protocolVersion); cougarObjectOutput.flush(); DehydratedExecutionContext ctx = cut.readExecutionContext(ioFactory.newCougarObjectInput(new ByteArrayInputStream(outputStream.toByteArray()), protocolVersion), "127.0.0.1", new X509Certificate[0], 0, protocolVersion); assertEquals(1, ctx.getIdentityTokens().size()); assertNull(ctx.getIdentity()); } @Test public void testIdentityChainMarshallsManyIdentities() throws IOException { final Identity joe = createIdentity("joeBloggs", "password", "fido123"); final Identity sam = createIdentity("samSpade", "password", "topcat999"); IdentityChain expected = new IdentityChainImpl(new ArrayList<Identity>() {{ add(joe); add(sam); }}); ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); CougarObjectOutput cougarObjectOutput = ioFactory.newCougarObjectOutput(outputStream, protocolVersion); cougarObjectOutput.writeString("127.0.0.1"); // address if (protocolVersion >= CougarProtocol.TRANSPORT_PROTOCOL_VERSION_START_TLS) { cougarObjectOutput.writeString(null); } cut.writeIdentity(expected, cougarObjectOutput, identityResolver); cut.writeRequestUUID(new RequestUUIDImpl(), cougarObjectOutput, protocolVersion); cut.writeReceivedTime(new Date(), cougarObjectOutput); cougarObjectOutput.writeBoolean(false); // traceEnabled cut.writeRequestTime(cougarObjectOutput, protocolVersion); cut.writeAdditionalParams(null, cougarObjectOutput, protocolVersion); cougarObjectOutput.flush(); DehydratedExecutionContext ctx = cut.readExecutionContext(ioFactory.newCougarObjectInput(new ByteArrayInputStream(outputStream.toByteArray()), protocolVersion), "127.0.0.1", new X509Certificate[0], 0, protocolVersion); assertEquals(2, ctx.getIdentityTokens().size()); assertNull(ctx.getIdentity()); } @Test public void testWriteArgument() throws IOException{ Parameter[] parameters = new Parameter[2]; parameters[0] = new Parameter("string",new ParameterType(String.class,null),true); parameters[1] = new Parameter("int", new ParameterType(Integer.class,null), false); ByteArrayOutputStream os = new ByteArrayOutputStream(); CougarObjectOutput cos = ioFactory.newCougarObjectOutput(os, protocolVersion); cut.writeArgs(parameters, new Object[] {"abc", 1}, cos); cos.flush(); Object[] args = cut.readArgs(parameters, ioFactory.newCougarObjectInput(new ByteArrayInputStream(os.toByteArray()), protocolVersion)); assertArrayEquals(new Object[]{"abc",1},args); } @Test public void testAdditionalInputArgs() throws IOException { Parameter[] parameters = new Parameter[2]; parameters[0] = new Parameter("int", new ParameterType(Integer.class,null), false); parameters[1] = new Parameter("string",new ParameterType(String.class,null),true); ByteArrayOutputStream os = new ByteArrayOutputStream(); CougarObjectOutput cos = ioFactory.newCougarObjectOutput(os, protocolVersion); cut.writeArgs(parameters, new Object[] {1,"abc"}, cos); cos.flush(); parameters = new Parameter[1]; parameters[0] = new Parameter("string",new ParameterType(String.class,null),true); Object[] args = cut.readArgs(parameters, ioFactory.newCougarObjectInput(new ByteArrayInputStream(os.toByteArray()), protocolVersion)); assertArrayEquals(new Object[]{"abc"}, args); } @Test public void testArgsOutOfOrder() throws IOException { Parameter[] parameters = new Parameter[2]; parameters[0] = new Parameter("int", new ParameterType(Integer.class,null), false); parameters[1] = new Parameter("string",new ParameterType(String.class,null),true); ByteArrayOutputStream os = new ByteArrayOutputStream(); CougarObjectOutput cos = ioFactory.newCougarObjectOutput(os, protocolVersion); cut.writeArgs(parameters, new Object[] {1,"abc"}, cos); cos.flush(); parameters = new Parameter[2]; parameters[0] = new Parameter("string",new ParameterType(String.class,null),true); parameters[1] = new Parameter("int", new ParameterType(Integer.class,null), false); Object[] args = cut.readArgs(parameters, ioFactory.newCougarObjectInput(new ByteArrayInputStream(os.toByteArray()), protocolVersion)); assertArrayEquals(new Object[]{"abc",1}, args); } private Identity createIdentity(String principalName, String credentialName, String credentialValue) { Principal principal = new PrincipalImpl(principalName); Credential credential = new CredentialImpl(credentialName, credentialValue); return new IdentityImpl(principal, credential); } /** * Equals methods in generated idd classes don't handle delegates * @param result * @return * @throws Exception */ private Object removeDelegates(Object result) throws Exception { if (! (result instanceof Transcribable)) { return result; } Transcribable transcribable = (Transcribable) result; final Object[] objects = new Object[transcribable.getParameters().length]; final int[] index = new int[1]; transcribable.transcribe(new TranscriptionOutput(){ @Override public void writeObject(Object obj, Parameter param, boolean client) throws Exception { if (obj == null) { objects[index[0]++] = null; } else if (param.getParameterType().getType() == Type.OBJECT) { objects[index[0]++] = removeDelegates(obj); } else if (param.getParameterType().getType() == Type.LIST) { if (obj.getClass().isArray()) { objects[index[0]] = Array.newInstance(param.getParameterType().getComponentTypes()[0].getImplementationClass(), Array.getLength(obj)); for (int i=0,limit=Array.getLength(obj); i<limit;i++) { Array.set(objects[index[0]], i, removeDelegates(Array.get(obj, i))); } index[0]++; } else { List list = (List) obj; objects[index[0]] = new ArrayList(); for (Object o : list) { ((List)objects[index[0]]).add(removeDelegates(o)); } index[0]++; } } else if (param.getParameterType().getType() == Type.SET) { Set set = (Set) obj; objects[index[0]] = new HashSet(); for (Object o : set) { ((Set)objects[index[0]]).add(removeDelegates(o)); } index[0]++; } else if (param.getParameterType().getType() == Type.MAP) { Map<Object,Object> map = (Map)obj; objects[index[0]] = new HashMap(); for (Entry entry : map.entrySet()) { ((Map)objects[index[0]]).put(removeDelegates(entry.getKey()), removeDelegates(entry.getValue())); } index[0]++; } else { objects[index[0]++] = obj; } }}, TranscribableParams.getAll(), false); Transcribable newObject = (Transcribable) result.getClass().newInstance(); index[0] = 0; newObject.transcribe(new TranscriptionInput() { @Override public <T> T readObject(Parameter param, boolean client) throws Exception { return (T) objects[index[0]++]; }}, TranscribableParams.getAll(), false); return newObject; } }