/*
* Copyright 2014, The Sporting Exchange Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.betfair.cougar.netutil.nio;
import com.betfair.cougar.netutil.nio.message.TLSResult;
import org.apache.mina.common.CloseFuture;
import org.apache.mina.common.IdleStatus;
import org.apache.mina.common.IoFilter;
import org.apache.mina.common.IoFilterChain;
import org.apache.mina.common.IoHandler;
import org.apache.mina.common.IoService;
import org.apache.mina.common.IoServiceConfig;
import org.apache.mina.common.IoSession;
import org.apache.mina.common.IoSessionConfig;
import org.apache.mina.common.TrafficMask;
import org.apache.mina.common.TransportType;
import org.apache.mina.common.WriteFuture;
import org.apache.mina.common.support.DefaultCloseFuture;
import org.apache.mina.common.support.DefaultWriteFuture;
import org.apache.mina.filter.SSLFilter;
import org.junit.Before;
import org.junit.Test;
import java.net.SocketAddress;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import static junit.framework.Assert.assertEquals;
import static junit.framework.Assert.assertFalse;
import static junit.framework.Assert.assertTrue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
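/**
 * Tests the Cougar transport protocol handshake between client and server {@link ICougarProtocol}
 * implementations of differing versions, covering version mismatches, disabled servers and
 * StartTLS negotiation. The two ends are connected through in-memory pseudo sessions rather than
 * real sockets.
 */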
public class CougarProtocolTest {
// Logger shared by the client and server protocol instances; recreated on every setupProtocol call.
private NioLogger logger;
// Client end of the connection under test (constructed by setupProtocol with isServer=false).
private ICougarProtocol client;
// Server end of the connection under test (constructed by setupProtocol with isServer=true).
private ICougarProtocol server;
// Mocked next filter handed to client.sessionOpened by each test.
private IoFilter.NextFilter nextFilter;
// Client-side session: messages written to it are delivered to the server protocol.
private PseudoIoSessionMock clientSession;
// Server-side session: messages written to it are delivered to the client protocol.
private PseudoIoSessionMock serverSession;
// Mocked SSL filter given to either or both ends by the TLS tests.
private SSLFilter sslFilter;
/**
 * Creates fresh Mockito mocks for the next filter and the shared SSL filter before every test.
 */
@Before
public void before() {
    nextFilter = mock(IoFilter.NextFilter.class);
    sslFilter = mock(SSLFilter.class);
}
/**
 * Sets up a plaintext-only pairing: neither end gets an SSL filter, and neither end supports
 * or requires TLS.
 */
private void setupProtocol(byte clientVersion, byte serverVersion) {
    final SSLFilter noFilter = null;
    setupProtocol(clientVersion, noFilter, false, false, serverVersion, noFilter, false, false);
}
/**
 * Sets up both ends, treating TLS as supported by an end exactly when that end is given a
 * non-null SSL filter.
 */
private void setupProtocol(byte clientVersion, SSLFilter clientSslFilter, boolean clientRequiresTls,
                           byte serverVersion, SSLFilter serverSslFilter, boolean serverRequiresTls) {
    final boolean clientSupportsTls = clientSslFilter != null;
    final boolean serverSupportsTls = serverSslFilter != null;
    setupProtocol(clientVersion, clientSslFilter, clientSupportsTls, clientRequiresTls,
            serverVersion, serverSslFilter, serverSupportsTls, serverRequiresTls);
}
/**
 * Builds a client and a server {@link ICougarProtocol} of the requested transport versions and
 * wires two {@link PseudoIoSessionMock}s together, so that anything written on one session is
 * handed directly to the protocol at the other end. Both protocols are enabled and share a
 * freshly created {@link NioLogger}.
 *
 * @param clientVersion     transport protocol version of the client end
 * @param clientSslFilter   SSL filter given to the client (only used from START_TLS onwards); may be null
 * @param clientSupportsTls whether the client supports TLS; must be false below START_TLS
 * @param clientRequiresTls whether the client requires TLS (only used from START_TLS onwards)
 * @param serverVersion     transport protocol version of the server end
 * @param serverSslFilter   SSL filter given to the server (only used from START_TLS onwards); may be null
 * @param serverSupportsTls whether the server supports TLS; must be false below START_TLS
 * @param serverRequiresTls whether the server requires TLS (only used from START_TLS onwards)
 * @throws IllegalArgumentException if either version is not one of the known versions, or if TLS
 *                                  support is requested for a version older than START_TLS
 */
private void setupProtocol(byte clientVersion, SSLFilter clientSslFilter, boolean clientSupportsTls, boolean clientRequiresTls,
                           byte serverVersion, SSLFilter serverSslFilter, boolean serverSupportsTls, boolean serverRequiresTls) {
    logger = new NioLogger("ALL");
    client = createProtocol(false, clientVersion, clientSslFilter, clientSupportsTls, clientRequiresTls);
    server = createProtocol(true, serverVersion, serverSslFilter, serverSupportsTls, serverRequiresTls);
    client.setEnabled(true);
    server.setEnabled(true);
    // The client's session delivers its writes to the server protocol, and vice versa.
    clientSession = createSession(server);
    serverSession = createSession(client);
    clientSession.setOtherSession(serverSession);
    serverSession.setOtherSession(clientSession);
}

/**
 * Creates one end of the connection for the given transport version, using the shared
 * {@code logger}. Every version receives the same numeric arguments (2000, 5000 and, from
 * START_TLS onwards, 0).
 *
 * @param isServer    true for the server end, false for the client end
 * @param version     transport protocol version to instantiate
 * @param sslFilter   SSL filter passed to START_TLS and later versions; may be null
 * @param supportsTls whether this end supports TLS; must be false below START_TLS
 * @param requiresTls whether this end requires TLS (only used from START_TLS onwards)
 * @return the new protocol instance
 * @throws IllegalArgumentException for an unknown version, or TLS support below START_TLS
 */
private ICougarProtocol createProtocol(boolean isServer, byte version, SSLFilter sslFilter,
                                       boolean supportsTls, boolean requiresTls) {
    if (version == CougarProtocol.TRANSPORT_PROTOCOL_VERSION_CLIENT_ONLY_RPC) {
        checkNoTlsSupport(isServer, supportsTls);
        return new CougarProtocol1(isServer, logger, CougarProtocol.TRANSPORT_PROTOCOL_VERSION_CLIENT_ONLY_RPC, 2000, 5000);
    }
    if (version == CougarProtocol.TRANSPORT_PROTOCOL_VERSION_BIDIRECTION_RPC) {
        checkNoTlsSupport(isServer, supportsTls);
        return new CougarProtocol2(isServer, logger, 2000, 5000);
    }
    if (version == CougarProtocol.TRANSPORT_PROTOCOL_VERSION_START_TLS) {
        return new CougarProtocol3(isServer, logger, 2000, 5000, sslFilter, supportsTls, requiresTls, 0);
    }
    if (version == CougarProtocol.TRANSPORT_PROTOCOL_VERSION_TIME_CONSTRAINTS) {
        return new CougarProtocol4(isServer, logger, 2000, 5000, sslFilter, supportsTls, requiresTls, 0);
    }
    if (version == CougarProtocol.TRANSPORT_PROTOCOL_VERSION_COMPOUND_REQUEST_UUID) {
        return new CougarProtocol5(isServer, logger, 2000, 5000, sslFilter, supportsTls, requiresTls, 0);
    }
    // Name the end that was misconfigured and report its own version.
    throw new IllegalArgumentException("Unsupported " + (isServer ? "server" : "client") + " version: " + version);
}

/**
 * Rejects a request for TLS support on the pre-START_TLS versions, which cannot negotiate TLS.
 *
 * @param isServer    which end is being configured, used in the error message
 * @param supportsTls the requested TLS support flag
 * @throws IllegalArgumentException if {@code supportsTls} is true
 */
private static void checkNoTlsSupport(boolean isServer, boolean supportsTls) {
    if (supportsTls) {
        throw new IllegalArgumentException((isServer ? "Server" : "Client") + " version doesn't support TLS");
    }
}
// =================== Version mismatching ========================
/**
 * A client restricted to a single version one above the maximum supported version must fail
 * the handshake against a max-version server.
 */
@Test
public void versionMismatchVNPlusOneOnly_VN() throws Exception {
    final byte maxSupported = CougarProtocol.TRANSPORT_PROTOCOL_VERSION_MAX_SUPPORTED;
    setupProtocol(maxSupported, maxSupported);
    final byte beyondMax = (byte) (maxSupported + 1);
    try {
        // Narrow the client protocol version range to exactly one unsupported version.
        CougarProtocol.setMinClientProtocolVersion(beyondMax);
        CougarProtocol.setMaxClientProtocolVersion(beyondMax);
        client.sessionOpened(nextFilter, clientSession);
        final ClientHandshake handshake = (ClientHandshake) clientSession.getAttribute(ClientHandshake.HANDSHAKE);
        assertTrue(handshake.await(5000));
        assertFalse(handshake.successful());
    } finally {
        // These setters are static, so restore the defaults for the other tests.
        CougarProtocol.setMinClientProtocolVersion(CougarProtocol.TRANSPORT_PROTOCOL_VERSION_MIN_SUPPORTED);
        CougarProtocol.setMaxClientProtocolVersion(maxSupported);
    }
}
// =================== Successful Handshakes ========================
/** Two V3 ends without TLS should agree on V3 at both ends. */
@Test
public void successfulHandshakePlaintextV3_V3() throws Exception {
    final byte v3 = CougarProtocol.TRANSPORT_PROTOCOL_VERSION_START_TLS;
    setupProtocol(v3, v3);
    client.sessionOpened(nextFilter, clientSession);
    final ClientHandshake handshake = (ClientHandshake) clientSession.getAttribute(ClientHandshake.HANDSHAKE);
    assertTrue(handshake.await(5000));
    assertTrue(handshake.successful());
    final String versionKey = CougarProtocol.PROTOCOL_VERSION_ATTR_NAME;
    assertEquals(v3, clientSession.getAttribute(versionKey));
    assertEquals(v3, serverSession.getAttribute(versionKey));
}
/** A V3 client talking to a V2 server should settle on V2 at both ends. */
@Test
public void successfulHandshakePlaintextV3_V2() throws Exception {
    final byte v2 = CougarProtocol.TRANSPORT_PROTOCOL_VERSION_BIDIRECTION_RPC;
    setupProtocol(CougarProtocol.TRANSPORT_PROTOCOL_VERSION_START_TLS, v2);
    client.sessionOpened(nextFilter, clientSession);
    final ClientHandshake handshake = (ClientHandshake) clientSession.getAttribute(ClientHandshake.HANDSHAKE);
    assertTrue(handshake.await(5000));
    assertTrue(handshake.successful());
    assertEquals(v2, clientSession.getAttribute(CougarProtocol.PROTOCOL_VERSION_ATTR_NAME));
    assertEquals(v2, serverSession.getAttribute(CougarProtocol.PROTOCOL_VERSION_ATTR_NAME));
}
/** A V3 client talking to a V1 server should record V1 on the client session. */
@Test
public void successfulHandshakePlaintextV3_V1() throws Exception {
    final byte v1 = CougarProtocol.TRANSPORT_PROTOCOL_VERSION_CLIENT_ONLY_RPC;
    setupProtocol(CougarProtocol.TRANSPORT_PROTOCOL_VERSION_START_TLS, v1);
    client.sessionOpened(nextFilter, clientSession);
    final ClientHandshake handshake = (ClientHandshake) clientSession.getAttribute(ClientHandshake.HANDSHAKE);
    final boolean completed = handshake.await(5000);
    assertTrue(completed);
    assertTrue(handshake.successful());
    assertEquals(v1, clientSession.getAttribute(CougarProtocol.PROTOCOL_VERSION_ATTR_NAME));
}
/** A V2 client talking to a V3 server should settle on V2 at both ends. */
@Test
public void successfulHandshakeV2_V3() throws Exception {
    final byte v2 = CougarProtocol.TRANSPORT_PROTOCOL_VERSION_BIDIRECTION_RPC;
    setupProtocol(v2, CougarProtocol.TRANSPORT_PROTOCOL_VERSION_START_TLS);
    client.sessionOpened(nextFilter, clientSession);
    final ClientHandshake handshake = (ClientHandshake) clientSession.getAttribute(ClientHandshake.HANDSHAKE);
    assertTrue(handshake.await(5000));
    assertTrue(handshake.successful());
    final String versionKey = CougarProtocol.PROTOCOL_VERSION_ATTR_NAME;
    assertEquals(v2, clientSession.getAttribute(versionKey));
    assertEquals(v2, serverSession.getAttribute(versionKey));
}
/** Two V2 ends should agree on V2 at both ends. */
@Test
public void successfulHandshakeV2_V2() throws Exception {
    final byte v2 = CougarProtocol.TRANSPORT_PROTOCOL_VERSION_BIDIRECTION_RPC;
    setupProtocol(v2, v2);
    client.sessionOpened(nextFilter, clientSession);
    final ClientHandshake handshake = (ClientHandshake) clientSession.getAttribute(ClientHandshake.HANDSHAKE);
    assertTrue(handshake.await(5000));
    assertTrue(handshake.successful());
    assertEquals(v2, clientSession.getAttribute(CougarProtocol.PROTOCOL_VERSION_ATTR_NAME));
    assertEquals(v2, serverSession.getAttribute(CougarProtocol.PROTOCOL_VERSION_ATTR_NAME));
}
/** A V2 client talking to a V1 server should record V1 on the client session. */
@Test
public void successfulHandshakeV2_V1() throws Exception {
    final byte v1 = CougarProtocol.TRANSPORT_PROTOCOL_VERSION_CLIENT_ONLY_RPC;
    setupProtocol(CougarProtocol.TRANSPORT_PROTOCOL_VERSION_BIDIRECTION_RPC, v1);
    client.sessionOpened(nextFilter, clientSession);
    final ClientHandshake handshake = (ClientHandshake) clientSession.getAttribute(ClientHandshake.HANDSHAKE);
    final boolean completed = handshake.await(5000);
    assertTrue(completed);
    assertTrue(handshake.successful());
    assertEquals(v1, clientSession.getAttribute(CougarProtocol.PROTOCOL_VERSION_ATTR_NAME));
}
/** A V1 client talking to a V3 server should record V1 on the server session. */
@Test
public void successfulHandshakeV1_V3() throws Exception {
    final byte v1 = CougarProtocol.TRANSPORT_PROTOCOL_VERSION_CLIENT_ONLY_RPC;
    setupProtocol(v1, CougarProtocol.TRANSPORT_PROTOCOL_VERSION_START_TLS);
    client.sessionOpened(nextFilter, clientSession);
    final ClientHandshake handshake = (ClientHandshake) clientSession.getAttribute(ClientHandshake.HANDSHAKE);
    assertTrue(handshake.await(5000));
    assertTrue(handshake.successful());
    assertEquals(v1, serverSession.getAttribute(CougarProtocol.PROTOCOL_VERSION_ATTR_NAME));
}
/** A V1 client talking to a V2 server should record V1 on the server session. */
@Test
public void successfulHandshakeV1_V2() throws Exception {
    final byte v1 = CougarProtocol.TRANSPORT_PROTOCOL_VERSION_CLIENT_ONLY_RPC;
    setupProtocol(v1, CougarProtocol.TRANSPORT_PROTOCOL_VERSION_BIDIRECTION_RPC);
    client.sessionOpened(nextFilter, clientSession);
    final ClientHandshake handshake = (ClientHandshake) clientSession.getAttribute(ClientHandshake.HANDSHAKE);
    final boolean completed = handshake.await(5000);
    assertTrue(completed);
    assertTrue(handshake.successful());
    assertEquals(v1, serverSession.getAttribute(CougarProtocol.PROTOCOL_VERSION_ATTR_NAME));
}
/**
 * A V1 client talking to a V1 server should complete the handshake successfully.
 * <p>
 * NOTE(review): the other V1 tests only assert the negotiated version attribute on the non-V1
 * end, so it is not asserted here. Confirm whether CougarProtocol1 records it on either session.
 */
@Test
public void successfulHandshakeV1_V1() throws Exception {
    // Both ends must be V1; pairing with a V2 server would just duplicate successfulHandshakeV1_V2.
    setupProtocol(CougarProtocol.TRANSPORT_PROTOCOL_VERSION_CLIENT_ONLY_RPC, CougarProtocol.TRANSPORT_PROTOCOL_VERSION_CLIENT_ONLY_RPC);
    client.sessionOpened(nextFilter, clientSession);
    ClientHandshake handshake = (ClientHandshake) clientSession.getAttribute(ClientHandshake.HANDSHAKE);
    assertTrue(handshake.await(5000));
    assertTrue(handshake.successful());
}
// =================== Server Disabled - Rejected ========================
/** A disabled V3 server must reject a V3 client. */
@Test
public void serverDisabledV3_V3() throws Exception {
    final byte v3 = CougarProtocol.TRANSPORT_PROTOCOL_VERSION_START_TLS;
    setupProtocol(v3, v3);
    server.setEnabled(false);
    client.sessionOpened(nextFilter, clientSession);
    final ClientHandshake handshake = (ClientHandshake) clientSession.getAttribute(ClientHandshake.HANDSHAKE);
    final boolean completed = handshake.await(5000);
    assertTrue(completed);
    assertFalse(handshake.successful());
}
/** A disabled V2 server must reject a V3 client. */
@Test
public void serverDisabledV3_V2() throws Exception {
    setupProtocol(CougarProtocol.TRANSPORT_PROTOCOL_VERSION_START_TLS,
            CougarProtocol.TRANSPORT_PROTOCOL_VERSION_BIDIRECTION_RPC);
    server.setEnabled(false);
    client.sessionOpened(nextFilter, clientSession);
    final ClientHandshake handshake = (ClientHandshake) clientSession.getAttribute(ClientHandshake.HANDSHAKE);
    assertTrue(handshake.await(5000));
    assertFalse(handshake.successful());
}
/** A disabled V1 server must reject a V3 client. */
@Test
public void serverDisabledV3_V1() throws Exception {
    setupProtocol(CougarProtocol.TRANSPORT_PROTOCOL_VERSION_START_TLS,
            CougarProtocol.TRANSPORT_PROTOCOL_VERSION_CLIENT_ONLY_RPC);
    server.setEnabled(false);
    client.sessionOpened(nextFilter, clientSession);
    final ClientHandshake handshake = (ClientHandshake) clientSession.getAttribute(ClientHandshake.HANDSHAKE);
    final boolean completed = handshake.await(5000);
    assertTrue(completed);
    assertFalse(handshake.successful());
}
/** A disabled V3 server must reject a V2 client. */
@Test
public void serverDisabledV2_V3() throws Exception {
    setupProtocol(CougarProtocol.TRANSPORT_PROTOCOL_VERSION_BIDIRECTION_RPC,
            CougarProtocol.TRANSPORT_PROTOCOL_VERSION_START_TLS);
    server.setEnabled(false);
    client.sessionOpened(nextFilter, clientSession);
    final ClientHandshake handshake = (ClientHandshake) clientSession.getAttribute(ClientHandshake.HANDSHAKE);
    assertTrue(handshake.await(5000));
    assertFalse(handshake.successful());
}
/** A disabled V2 server must reject a V2 client. */
@Test
public void serverDisabledV2_V2() throws Exception {
    final byte v2 = CougarProtocol.TRANSPORT_PROTOCOL_VERSION_BIDIRECTION_RPC;
    setupProtocol(v2, v2);
    server.setEnabled(false);
    client.sessionOpened(nextFilter, clientSession);
    final ClientHandshake handshake = (ClientHandshake) clientSession.getAttribute(ClientHandshake.HANDSHAKE);
    final boolean completed = handshake.await(5000);
    assertTrue(completed);
    assertFalse(handshake.successful());
}
/** A disabled V1 server must reject a V2 client. */
@Test
public void serverDisabledV2_V1() throws Exception {
    setupProtocol(CougarProtocol.TRANSPORT_PROTOCOL_VERSION_BIDIRECTION_RPC,
            CougarProtocol.TRANSPORT_PROTOCOL_VERSION_CLIENT_ONLY_RPC);
    server.setEnabled(false);
    client.sessionOpened(nextFilter, clientSession);
    final ClientHandshake handshake = (ClientHandshake) clientSession.getAttribute(ClientHandshake.HANDSHAKE);
    assertTrue(handshake.await(5000));
    assertFalse(handshake.successful());
}
/** A disabled V3 server must reject a V1 client. */
@Test
public void serverDisabledV1_V3() throws Exception {
    setupProtocol(CougarProtocol.TRANSPORT_PROTOCOL_VERSION_CLIENT_ONLY_RPC,
            CougarProtocol.TRANSPORT_PROTOCOL_VERSION_START_TLS);
    server.setEnabled(false);
    client.sessionOpened(nextFilter, clientSession);
    final ClientHandshake handshake = (ClientHandshake) clientSession.getAttribute(ClientHandshake.HANDSHAKE);
    final boolean completed = handshake.await(5000);
    assertTrue(completed);
    assertFalse(handshake.successful());
}
/** A disabled V2 server must reject a V1 client. */
@Test
public void serverDisabledV1_V2() throws Exception {
    setupProtocol(CougarProtocol.TRANSPORT_PROTOCOL_VERSION_CLIENT_ONLY_RPC,
            CougarProtocol.TRANSPORT_PROTOCOL_VERSION_BIDIRECTION_RPC);
    server.setEnabled(false);
    client.sessionOpened(nextFilter, clientSession);
    final ClientHandshake handshake = (ClientHandshake) clientSession.getAttribute(ClientHandshake.HANDSHAKE);
    assertTrue(handshake.await(5000));
    assertFalse(handshake.successful());
}
/**
 * A disabled V1 server must reject a V1 client. Both ends are V1; pairing with a V2 server
 * would only repeat serverDisabledV1_V2.
 */
@Test
public void serverDisabledV1_V1() throws Exception {
    setupProtocol(CougarProtocol.TRANSPORT_PROTOCOL_VERSION_CLIENT_ONLY_RPC, CougarProtocol.TRANSPORT_PROTOCOL_VERSION_CLIENT_ONLY_RPC);
    server.setEnabled(false);
    client.sessionOpened(nextFilter, clientSession);
    ClientHandshake handshake = (ClientHandshake) clientSession.getAttribute(ClientHandshake.HANDSHAKE);
    assertTrue(handshake.await(5000));
    assertFalse(handshake.successful());
}
// ============= TLS - Client Requires, Server too old =============
/** A V3 client that requires TLS cannot connect to a V2 server, which predates StartTLS. */
@Test
public void clientRequiresTlsServerTooOldV3_V2() throws Exception {
    final byte clientVersion = CougarProtocol.TRANSPORT_PROTOCOL_VERSION_START_TLS;
    final byte serverVersion = CougarProtocol.TRANSPORT_PROTOCOL_VERSION_BIDIRECTION_RPC;
    setupProtocol(clientVersion, sslFilter, true, serverVersion, null, false);
    client.sessionOpened(nextFilter, clientSession);
    final ClientHandshake handshake = (ClientHandshake) clientSession.getAttribute(ClientHandshake.HANDSHAKE);
    assertTrue(handshake.await(5000));
    assertFalse(handshake.successful());
}
/** A V3 client that requires TLS cannot connect to a V1 server, which predates StartTLS. */
@Test
public void clientRequiresTlsServerTooOldV3_V1() throws Exception {
    final byte clientVersion = CougarProtocol.TRANSPORT_PROTOCOL_VERSION_START_TLS;
    final byte serverVersion = CougarProtocol.TRANSPORT_PROTOCOL_VERSION_CLIENT_ONLY_RPC;
    setupProtocol(clientVersion, sslFilter, true, serverVersion, null, false);
    client.sessionOpened(nextFilter, clientSession);
    final ClientHandshake handshake = (ClientHandshake) clientSession.getAttribute(ClientHandshake.HANDSHAKE);
    final boolean completed = handshake.await(5000);
    assertTrue(completed);
    assertFalse(handshake.successful());
}
// ============= TLS - Client too old, Server requires =============
/** A V3 server that requires TLS must reject a V2 client, which predates StartTLS. */
@Test
public void clientTooOldServerRequiresTlsV2_V3() throws Exception {
    final byte clientVersion = CougarProtocol.TRANSPORT_PROTOCOL_VERSION_BIDIRECTION_RPC;
    final byte serverVersion = CougarProtocol.TRANSPORT_PROTOCOL_VERSION_START_TLS;
    setupProtocol(clientVersion, null, false, serverVersion, sslFilter, true);
    client.sessionOpened(nextFilter, clientSession);
    final ClientHandshake handshake = (ClientHandshake) clientSession.getAttribute(ClientHandshake.HANDSHAKE);
    assertTrue(handshake.await(5000));
    assertFalse(handshake.successful());
}
/** A V3 server that requires TLS must reject a V1 client, which predates StartTLS. */
@Test
public void clientTooOldServerRequiresTlsV1_V3() throws Exception {
    final byte clientVersion = CougarProtocol.TRANSPORT_PROTOCOL_VERSION_CLIENT_ONLY_RPC;
    final byte serverVersion = CougarProtocol.TRANSPORT_PROTOCOL_VERSION_START_TLS;
    setupProtocol(clientVersion, null, false, serverVersion, sslFilter, true);
    client.sessionOpened(nextFilter, clientSession);
    final ClientHandshake handshake = (ClientHandshake) clientSession.getAttribute(ClientHandshake.HANDSHAKE);
    final boolean completed = handshake.await(5000);
    assertTrue(completed);
    assertFalse(handshake.successful());
}
// ============= TLS - Client & server versions sufficient, client requires, server doesn't support =====
/** A V3 client that requires TLS must fail against a V3 server that has no TLS support. */
@Test
public void clientRequiresTlsServerDoesntSupportTlsV3_V3() throws Exception {
    final byte v3 = CougarProtocol.TRANSPORT_PROTOCOL_VERSION_START_TLS;
    setupProtocol(v3, sslFilter, true, v3, null, false);
    client.sessionOpened(nextFilter, clientSession);
    final ClientHandshake handshake = (ClientHandshake) clientSession.getAttribute(ClientHandshake.HANDSHAKE);
    assertTrue(handshake.await(5000));
    assertFalse(handshake.successful());
}
// ============= TLS - Client & server versions sufficient, client requires, server supports =====
/**
 * A V3 client that requires TLS and a V3 server that supports it should negotiate SSL, with
 * both ends installing the SSL filter.
 */
@Test
public void clientRequiresTlsServerSupportsTlsV3_V3() throws Exception {
    final byte v3 = CougarProtocol.TRANSPORT_PROTOCOL_VERSION_START_TLS;
    setupProtocol(v3, sslFilter, true, v3, sslFilter, false);
    client.sessionOpened(nextFilter, clientSession);
    final ClientHandshake handshake = (ClientHandshake) clientSession.getAttribute(ClientHandshake.HANDSHAKE);
    assertTrue(handshake.await(5000));
    assertTrue(handshake.successful());

    // The SSL filter must be added at the head of each chain.
    final IoFilterChain clientChain = clientSession.getFilterChain();
    final IoFilterChain serverChain = serverSession.getFilterChain();
    verify(clientChain).addFirst("ssl", sslFilter);
    verify(serverChain).addFirst("ssl", sslFilter);

    assertEquals(v3, clientSession.getAttribute(CougarProtocol.PROTOCOL_VERSION_ATTR_NAME));
    assertEquals(TLSResult.SSL, clientSession.getAttribute(CougarProtocol.NEGOTIATED_TLS_LEVEL_ATTR_NAME));
    assertEquals(v3, serverSession.getAttribute(CougarProtocol.PROTOCOL_VERSION_ATTR_NAME));
    assertEquals(TLSResult.SSL, serverSession.getAttribute(CougarProtocol.NEGOTIATED_TLS_LEVEL_ATTR_NAME));
}
// ============= TLS - Client & server versions sufficient, client requires, server requires =====
/** When both V3 ends require TLS, SSL should be negotiated and installed on both ends. */
@Test
public void clientRequiresTlsServerRequiresTlsV3_V3() throws Exception {
    final byte v3 = CougarProtocol.TRANSPORT_PROTOCOL_VERSION_START_TLS;
    setupProtocol(v3, sslFilter, true, v3, sslFilter, true);
    client.sessionOpened(nextFilter, clientSession);
    final ClientHandshake handshake = (ClientHandshake) clientSession.getAttribute(ClientHandshake.HANDSHAKE);
    final boolean completed = handshake.await(5000);
    assertTrue(completed);
    assertTrue(handshake.successful());

    verify(clientSession.getFilterChain()).addFirst("ssl", sslFilter);
    verify(serverSession.getFilterChain()).addFirst("ssl", sslFilter);

    final String versionKey = CougarProtocol.PROTOCOL_VERSION_ATTR_NAME;
    final String tlsKey = CougarProtocol.NEGOTIATED_TLS_LEVEL_ATTR_NAME;
    assertEquals(v3, clientSession.getAttribute(versionKey));
    assertEquals(TLSResult.SSL, clientSession.getAttribute(tlsKey));
    assertEquals(v3, serverSession.getAttribute(versionKey));
    assertEquals(TLSResult.SSL, serverSession.getAttribute(tlsKey));
}
// ============= TLS - Client & server versions sufficient, client supports, server doesn't =====
/**
 * A V3 client that only supports TLS talking to a V3 server without TLS should fall back to
 * plaintext, with neither end installing the SSL filter.
 */
@Test
public void clientSupportsTlsServerDoesntSupportTlsV3_V3() throws Exception {
    final byte v3 = CougarProtocol.TRANSPORT_PROTOCOL_VERSION_START_TLS;
    setupProtocol(v3, sslFilter, false, v3, null, false);
    client.sessionOpened(nextFilter, clientSession);
    final ClientHandshake handshake = (ClientHandshake) clientSession.getAttribute(ClientHandshake.HANDSHAKE);
    assertTrue(handshake.await(5000));
    assertTrue(handshake.successful());

    final IoFilterChain clientChain = clientSession.getFilterChain();
    final IoFilterChain serverChain = serverSession.getFilterChain();
    verify(clientChain, times(0)).addFirst("ssl", sslFilter);
    verify(serverChain, times(0)).addFirst("ssl", sslFilter);

    assertEquals(v3, clientSession.getAttribute(CougarProtocol.PROTOCOL_VERSION_ATTR_NAME));
    assertEquals(TLSResult.PLAINTEXT, clientSession.getAttribute(CougarProtocol.NEGOTIATED_TLS_LEVEL_ATTR_NAME));
    assertEquals(v3, serverSession.getAttribute(CougarProtocol.PROTOCOL_VERSION_ATTR_NAME));
    assertEquals(TLSResult.PLAINTEXT, serverSession.getAttribute(CougarProtocol.NEGOTIATED_TLS_LEVEL_ATTR_NAME));
}
// ============= TLS - Client & server versions sufficient, client supports, server supports =====
/** When both V3 ends support (but do not require) TLS, SSL should still be negotiated. */
@Test
public void clientSupportsTlsServerSupportsTlsV3_V3() throws Exception {
    final byte v3 = CougarProtocol.TRANSPORT_PROTOCOL_VERSION_START_TLS;
    setupProtocol(v3, sslFilter, false, v3, sslFilter, false);
    client.sessionOpened(nextFilter, clientSession);
    final ClientHandshake handshake = (ClientHandshake) clientSession.getAttribute(ClientHandshake.HANDSHAKE);
    final boolean completed = handshake.await(5000);
    assertTrue(completed);
    assertTrue(handshake.successful());

    verify(clientSession.getFilterChain()).addFirst("ssl", sslFilter);
    verify(serverSession.getFilterChain()).addFirst("ssl", sslFilter);

    final String versionKey = CougarProtocol.PROTOCOL_VERSION_ATTR_NAME;
    final String tlsKey = CougarProtocol.NEGOTIATED_TLS_LEVEL_ATTR_NAME;
    assertEquals(v3, clientSession.getAttribute(versionKey));
    assertEquals(TLSResult.SSL, clientSession.getAttribute(tlsKey));
    assertEquals(v3, serverSession.getAttribute(versionKey));
    assertEquals(TLSResult.SSL, serverSession.getAttribute(tlsKey));
}
// ============= TLS - Client & server versions sufficient, client supports, server requires =====
/** A V3 client that supports TLS and a V3 server that requires it should negotiate SSL. */
@Test
public void clientSupportsTlsServerRequiresTlsV3_V3() throws Exception {
    final byte v3 = CougarProtocol.TRANSPORT_PROTOCOL_VERSION_START_TLS;
    setupProtocol(v3, sslFilter, false, v3, sslFilter, true);
    client.sessionOpened(nextFilter, clientSession);
    final ClientHandshake handshake = (ClientHandshake) clientSession.getAttribute(ClientHandshake.HANDSHAKE);
    assertTrue(handshake.await(5000));
    assertTrue(handshake.successful());

    final IoFilterChain clientChain = clientSession.getFilterChain();
    final IoFilterChain serverChain = serverSession.getFilterChain();
    verify(clientChain).addFirst("ssl", sslFilter);
    verify(serverChain).addFirst("ssl", sslFilter);

    assertEquals(v3, clientSession.getAttribute(CougarProtocol.PROTOCOL_VERSION_ATTR_NAME));
    assertEquals(TLSResult.SSL, clientSession.getAttribute(CougarProtocol.NEGOTIATED_TLS_LEVEL_ATTR_NAME));
    assertEquals(v3, serverSession.getAttribute(CougarProtocol.PROTOCOL_VERSION_ATTR_NAME));
    assertEquals(TLSResult.SSL, serverSession.getAttribute(CougarProtocol.NEGOTIATED_TLS_LEVEL_ATTR_NAME));
}
// ============= TLS - Client & server versions sufficient, client doesn't support, server doesn't support =====
/** Two V3 ends with no TLS support should agree on plaintext and never add the SSL filter. */
@Test
public void clientDoesntSupportTlsServerDoesntSupportTlsV3_V3() throws Exception {
    final byte v3 = CougarProtocol.TRANSPORT_PROTOCOL_VERSION_START_TLS;
    setupProtocol(v3, null, false, v3, null, false);
    client.sessionOpened(nextFilter, clientSession);
    final ClientHandshake handshake = (ClientHandshake) clientSession.getAttribute(ClientHandshake.HANDSHAKE);
    final boolean completed = handshake.await(5000);
    assertTrue(completed);
    assertTrue(handshake.successful());

    verify(clientSession.getFilterChain(), times(0)).addFirst("ssl", sslFilter);
    verify(serverSession.getFilterChain(), times(0)).addFirst("ssl", sslFilter);

    final String versionKey = CougarProtocol.PROTOCOL_VERSION_ATTR_NAME;
    final String tlsKey = CougarProtocol.NEGOTIATED_TLS_LEVEL_ATTR_NAME;
    assertEquals(v3, clientSession.getAttribute(versionKey));
    assertEquals(TLSResult.PLAINTEXT, clientSession.getAttribute(tlsKey));
    assertEquals(v3, serverSession.getAttribute(versionKey));
    assertEquals(TLSResult.PLAINTEXT, serverSession.getAttribute(tlsKey));
}
// ============= TLS - Client & server versions sufficient, client doesn't support, server supports =====
/**
 * A V3 client without TLS talking to a V3 server that merely supports TLS should fall back to
 * plaintext, with neither end installing the SSL filter.
 */
@Test
public void clientDoesntSupportTlsServerSupportsTlsV3_V3() throws Exception {
    final byte v3 = CougarProtocol.TRANSPORT_PROTOCOL_VERSION_START_TLS;
    setupProtocol(v3, null, false, v3, sslFilter, false);
    client.sessionOpened(nextFilter, clientSession);
    final ClientHandshake handshake = (ClientHandshake) clientSession.getAttribute(ClientHandshake.HANDSHAKE);
    assertTrue(handshake.await(5000));
    assertTrue(handshake.successful());

    final IoFilterChain clientChain = clientSession.getFilterChain();
    final IoFilterChain serverChain = serverSession.getFilterChain();
    verify(clientChain, times(0)).addFirst("ssl", sslFilter);
    verify(serverChain, times(0)).addFirst("ssl", sslFilter);

    assertEquals(v3, clientSession.getAttribute(CougarProtocol.PROTOCOL_VERSION_ATTR_NAME));
    assertEquals(TLSResult.PLAINTEXT, clientSession.getAttribute(CougarProtocol.NEGOTIATED_TLS_LEVEL_ATTR_NAME));
    assertEquals(v3, serverSession.getAttribute(CougarProtocol.PROTOCOL_VERSION_ATTR_NAME));
    assertEquals(TLSResult.PLAINTEXT, serverSession.getAttribute(CougarProtocol.NEGOTIATED_TLS_LEVEL_ATTR_NAME));
}
// ============= TLS - Client & server versions sufficient, client doesn't support, server requires =====
/** A V3 server that requires TLS must reject a V3 client with no TLS support. */
@Test
public void clientDoesntSupportTlsServerRequiresTlsV3_V3() throws Exception {
    final byte v3 = CougarProtocol.TRANSPORT_PROTOCOL_VERSION_START_TLS;
    setupProtocol(v3, null, false, v3, sslFilter, true);
    client.sessionOpened(nextFilter, clientSession);
    final ClientHandshake handshake = (ClientHandshake) clientSession.getAttribute(ClientHandshake.HANDSHAKE);
    final boolean completed = handshake.await(5000);
    assertTrue(completed);
    assertFalse(handshake.successful());
}
/**
 * Creates a pseudo session whose writes are delivered to {@code otherEnd}. The caller must still
 * pair it with the opposite session via {@code setOtherSession}.
 */
private PseudoIoSessionMock createSession(ICougarProtocol otherEnd) {
    final PseudoIoSessionMock session = new PseudoIoSessionMock(otherEnd);
    return session;
}
private class PseudoIoSessionMock implements IoSession {
private IoSession mock = mock(IoSession.class);
private IoFilterChain mockFilterChain = mock(IoFilterChain.class);
private IoFilter.NextFilter nextFilter = mock(IoFilter.NextFilter.class);
private ICougarProtocol otherEnd;
private Map<String, Object> attributes = new HashMap<String, Object>();
private PseudoIoSessionMock otherSession;
private PseudoIoSessionMock(ICougarProtocol otherEnd) {
this.otherEnd = otherEnd;
}
public IoSession getMock() {
return mock;
}
public IoFilter.NextFilter getNextFilter() {
return nextFilter;
}
@Override
public IoService getService() {
return mock.getService();
}
@Override
public IoServiceConfig getServiceConfig() {
return mock.getServiceConfig();
}
@Override
public IoHandler getHandler() {
return mock.getHandler();
}
@Override
public IoSessionConfig getConfig() {
return mock.getConfig();
}
@Override
public IoFilterChain getFilterChain() {
return mockFilterChain;
}
@Override
public WriteFuture write(Object message) {
DefaultWriteFuture ret = new DefaultWriteFuture(this);
try {
otherEnd.messageReceived(nextFilter, otherSession, message);
ret.setWritten(true);
}
catch (Exception e) {
e.printStackTrace();
ret.setWritten(false);
}
// behave like the real thing..
if (attributes.get(SSLFilter.DISABLE_ENCRYPTION_ONCE) != null) {
attributes.remove(SSLFilter.DISABLE_ENCRYPTION_ONCE);
}
return ret;
}
@Override
public CloseFuture close() {
DefaultCloseFuture ret = new DefaultCloseFuture(this);
try {
otherEnd.sessionClosed(nextFilter, otherSession);
ret.setClosed();
} catch (Exception e) {
e.printStackTrace();
ret.setClosed();
}
return ret;
}
@Override
public Object getAttachment() {
return mock.getAttachment();
}
@Override
public Object setAttachment(Object attachment) {
return mock.setAttachment(attachment);
}
@Override
public Object getAttribute(String key) {
return attributes.get(key);
}
@Override
public Object setAttribute(String key, Object value) {
return attributes.put(key, value);
}
@Override
public Object setAttribute(String key) {
return attributes.put(key, Boolean.TRUE);
}
@Override
public Object removeAttribute(String key) {
return attributes.remove(key);
}
@Override
public boolean containsAttribute(String key) {
return attributes.containsKey(key);
}
@Override
public Set<String> getAttributeKeys() {
return attributes.keySet();
}
@Override
public TransportType getTransportType() {
return mock.getTransportType();
}
@Override
public boolean isConnected() {
return mock.isConnected();
}
@Override
public boolean isClosing() {
return mock.isClosing();
}
@Override
public CloseFuture getCloseFuture() {
return mock.getCloseFuture();
}
@Override
public SocketAddress getRemoteAddress() {
return mock.getRemoteAddress();
}
@Override
public SocketAddress getLocalAddress() {
return mock.getLocalAddress();
}
@Override
public SocketAddress getServiceAddress() {
return mock.getServiceAddress();
}
@Override
public int getIdleTime(IdleStatus status) {
return mock.getIdleTime(status);
}
@Override
public long getIdleTimeInMillis(IdleStatus status) {
return mock.getIdleTimeInMillis(status);
}
@Override
public void setIdleTime(IdleStatus status, int idleTime) {
mock.setIdleTime(status, idleTime);
}
@Override
public int getWriteTimeout() {
return mock.getWriteTimeout();
}
@Override
public long getWriteTimeoutInMillis() {
return mock.getWriteTimeoutInMillis();
}
@Override
public void setWriteTimeout(int writeTimeout) {
mock.setWriteTimeout(writeTimeout);
}
@Override
public TrafficMask getTrafficMask() {
return mock.getTrafficMask();
}
@Override
public void setTrafficMask(TrafficMask trafficMask) {
mock.setTrafficMask(trafficMask);
}
@Override
public void suspendRead() {
mock.suspendRead();
}
@Override
public void suspendWrite() {
mock.suspendWrite();
}
@Override
public void resumeRead() {
mock.resumeRead();
}
@Override
public void resumeWrite() {
mock.resumeWrite();
}
@Override
public long getReadBytes() {
return mock.getReadBytes();
}
@Override
public long getWrittenBytes() {
return mock.getWrittenBytes();
}
@Override
public long getReadMessages() {
return mock.getReadMessages();
}
@Override
public long getWrittenMessages() {
return mock.getWrittenMessages();
}
@Override
public long getWrittenWriteRequests() {
return mock.getWrittenWriteRequests();
}
@Override
public int getScheduledWriteRequests() {
return mock.getScheduledWriteRequests();
}
@Override
public int getScheduledWriteBytes() {
return mock.getScheduledWriteBytes();
}
@Override
public long getCreationTime() {
return mock.getCreationTime();
}
@Override
public long getLastIoTime() {
return mock.getLastIoTime();
}
@Override
public long getLastReadTime() {
return mock.getLastReadTime();
}
@Override
public long getLastWriteTime() {
return mock.getLastWriteTime();
}
@Override
public boolean isIdle(IdleStatus status) {
return mock.isIdle(status);
}
@Override
public int getIdleCount(IdleStatus status) {
return mock.getIdleCount(status);
}
@Override
public long getLastIdleTime(IdleStatus status) {
return mock.getLastIdleTime(status);
}
public void setOtherSession(PseudoIoSessionMock otherSession) {
this.otherSession = otherSession;
}
}
}