/*
* Copyright 2014, The Sporting Exchange Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.betfair.cougar.client.socket;
import com.betfair.cougar.netutil.nio.message.ProtocolMessage;
import com.betfair.cougar.util.JMXReportingThreadPoolExecutor;
import org.apache.mina.common.IoSession;
import org.junit.Before;
import org.junit.Test;
import java.lang.reflect.Field;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.*;
/**
* Session factory tests
*/
public class IoSessionFactoryTest {
private IoSessionFactory sessionFactory;
private IoSession connectedSession;
@Test
public void testGetSessionAlwaysReturnsConnectedSessions() throws Exception {
for (int i = 0; i < 10; i++) {
assertTrue(connectedSession.equals(sessionFactory.getSession()));
}
}
@Test
public void testOpenSession() throws Exception {
final InetSocketAddress address1 = new InetSocketAddress("host1", 9003);
final InetSocketAddress address2 = new InetSocketAddress("host2", 9003);
sessionFactory.openSession(address1);
sessionFactory.openSession(address2);
Field pendingConnections = IoSessionFactory.class.getDeclaredField("pendingConnections");
pendingConnections.setAccessible(true);
Map<SocketAddress, Object> map = (Map<SocketAddress, Object>) pendingConnections.get(sessionFactory);
assertEquals(2, map.size());
assertTrue(map.keySet().contains(address1));
assertTrue(map.keySet().contains(address2));
sessionFactory.openSession(address1);
sessionFactory.openSession(address2);
map = (Map<SocketAddress, Object>) pendingConnections.get(sessionFactory);
assertEquals(2, map.size());
assertTrue(map.keySet().contains(address1));
assertTrue(map.keySet().contains(address2));
}
@Test
public void testCloseSession() throws Exception {
final InetSocketAddress address1 = new InetSocketAddress("host1", 9003);
final InetSocketAddress address2 = new InetSocketAddress("host2", 9003);
sessionFactory.openSession(address1);
sessionFactory.openSession(address2);
sessionFactory.closeSession(connectedSession.getRemoteAddress(), false);
sessionFactory.closeSession(address1, false);
sessionFactory.closeSession(address2, false);
verify(sessionFactory).close(connectedSession, false);
}
@Before
public void setup() throws Exception {
sessionFactory = mock(IoSessionFactory.class);
Field sessionsField = IoSessionFactory.class.getDeclaredField("sessions");
sessionsField.setAccessible(true);
Map<SocketAddress, IoSession> sessions = new HashMap<SocketAddress, IoSession>();
connectedSession = getConnectedSession();
sessions.put(connectedSession.getRemoteAddress(), connectedSession);
final IoSession notConnectedSession = getNotConnectedSession();
sessions.put(notConnectedSession.getRemoteAddress(), notConnectedSession);
final IoSession closingSession = getClosingSession();
sessions.put(closingSession.getRemoteAddress(), closingSession);
final IoSession suspendedSession = getSuspendedSession();
sessions.put(suspendedSession.getRemoteAddress(), suspendedSession);
final IoSession disconnectedSession = getDisconnectedSession();
sessions.put(disconnectedSession.getRemoteAddress(), disconnectedSession);
sessionsField.set(sessionFactory, sessions);
when(sessionFactory.getSession()).thenCallRealMethod();
when(sessionFactory.isAvailable(any(IoSession.class))).thenCallRealMethod();
Field lockField = IoSessionFactory.class.getDeclaredField("lock");
lockField.setAccessible(true);
lockField.set(sessionFactory, new Object());
Field pendingConnections = IoSessionFactory.class.getDeclaredField("pendingConnections");
pendingConnections.setAccessible(true);
pendingConnections.set(sessionFactory, new HashMap());
Field executor = IoSessionFactory.class.getDeclaredField("reconnectExecutor");
executor.setAccessible(true);
ExecutorService mockExecutor = mock(JMXReportingThreadPoolExecutor.class);
when(mockExecutor.submit(any(Runnable.class))).thenReturn(null);
executor.set(sessionFactory, mockExecutor);
doCallRealMethod().when(sessionFactory).openSession(any(SocketAddress.class));
doCallRealMethod().when(sessionFactory).closeSession(any(SocketAddress.class), anyBoolean());
}
private IoSession getConnectedSession() {
return getSession(1, true, false, false, false);
}
private IoSession getNotConnectedSession() {
return getSession(2, false, false, false, false);
}
private IoSession getClosingSession() {
return getSession(3, true, true, false, false);
}
private IoSession getSuspendedSession() {
return getSession(4, true, false, true, false);
}
private IoSession getDisconnectedSession() {
return getSession(5, true, false, false, true);
}
private IoSession getSession(int id, boolean isConnected, boolean isClosing, boolean isSuspended, boolean isDisconnected) {
final IoSession ioSession = mock(IoSession.class);
when(ioSession.isConnected()).thenReturn(isConnected);
when(ioSession.isClosing()).thenReturn(isClosing);
when(ioSession.containsAttribute(ProtocolMessage.ProtocolMessageType.SUSPEND.name())).thenReturn(isSuspended);
when(ioSession.containsAttribute(ProtocolMessage.ProtocolMessageType.DISCONNECT.name())).thenReturn(isDisconnected);
when(ioSession.getRemoteAddress()).thenReturn(new InetSocketAddress("1.1.1." + id, 9003));
return ioSession;
}
}