/*
* Copyright 2014, Simon Matić Langford
* Copyright 2014, The Sporting Exchange Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.betfair.cougar.netutil.nio;
import com.betfair.cougar.core.api.transcription.TranscribableParams;
import com.betfair.cougar.netutil.nio.message.*;
import com.betfair.cougar.util.jmx.JMXControl;
import org.apache.mina.common.*;
import org.apache.mina.filter.SSLFilter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.jmx.export.annotation.ManagedAttribute;
import java.util.Arrays;
import java.util.Collections;
import java.util.EnumSet;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import static com.betfair.cougar.netutil.nio.NioLogger.LoggingLevel.PROTOCOL;
import static com.betfair.cougar.netutil.nio.NioLogger.LoggingLevel.SESSION;
/**
*
*/
public class CougarProtocol4 extends CougarProtocol implements ICougarProtocol {
private static final Logger LOG = LoggerFactory.getLogger(CougarProtocol.class);
private static final KeepAliveMessage KEEP_ALIVE = new KeepAliveMessage();
public static final String PROTOCOL_VERSION_ATTR_NAME = "CougarProtocol.sessionProtocolVersion";
public static final String IS_SERVER_ATTR_NAME = "CougarProtocol.isServer";
public static final String NEGOTIATED_TLS_LEVEL_ATTR_NAME = "CougarProtocol.negotiatedTlsLevel";
public static final String CLIENT_CERTS_ATTR_NAME = "CougarProtocol.clientCertificateChain";
public static final String TSSF_ATTR_NAME = "CougarProtocol.transportSecurityStrengthFactor";
public static final byte TRANSPORT_PROTOCOL_VERSION_CLIENT_ONLY_RPC = 1;
public static final byte TRANSPORT_PROTOCOL_VERSION_BIDIRECTION_RPC = 2;
public static final byte TRANSPORT_PROTOCOL_VERSION_START_TLS = 3;
public static final byte TRANSPORT_PROTOCOL_VERSION_TIME_CONSTRAINTS = 4;
public static final byte TRANSPORT_PROTOCOL_VERSION_COMPOUND_REQUEST_UUID = 5;
public static final byte TRANSPORT_PROTOCOL_VERSION_MIN_SUPPORTED = TRANSPORT_PROTOCOL_VERSION_CLIENT_ONLY_RPC;
public static final byte TRANSPORT_PROTOCOL_VERSION_MAX_SUPPORTED = TRANSPORT_PROTOCOL_VERSION_COMPOUND_REQUEST_UUID;
public static final byte TRANSPORT_PROTOCOL_VERSION_UNSUPPORTED = TRANSPORT_PROTOCOL_VERSION_MIN_SUPPORTED - 1;
// these allow tests to force us to particular versions of the protocol, even invalid ones
private static byte maxServerProtocolVersion = TRANSPORT_PROTOCOL_VERSION_MAX_SUPPORTED;
private static byte maxClientProtocolVersion = TRANSPORT_PROTOCOL_VERSION_MAX_SUPPORTED;
private static byte minServerProtocolVersion = TRANSPORT_PROTOCOL_VERSION_MIN_SUPPORTED;
private static byte minClientProtocolVersion = TRANSPORT_PROTOCOL_VERSION_MIN_SUPPORTED;
public static void setMaxServerProtocolVersion(byte maxServerProtocolVersion) {
CougarProtocol4.maxServerProtocolVersion = maxServerProtocolVersion;
}
public static void setMaxClientProtocolVersion(byte maxClientProtocolVersion) {
CougarProtocol4.maxClientProtocolVersion = maxClientProtocolVersion;
}
public static void setMinServerProtocolVersion(byte minServerProtocolVersion) {
CougarProtocol4.minServerProtocolVersion = minServerProtocolVersion;
}
public static void setMinClientProtocolVersion(byte minClientProtocolVersion) {
CougarProtocol4.minClientProtocolVersion = minClientProtocolVersion;
}
private byte[] getServerAcceptableVersions() {
byte[] ret = new byte[(maxServerProtocolVersion - minServerProtocolVersion) + 1];
int ind = 0;
for (byte i = maxServerProtocolVersion; i >= minServerProtocolVersion; i--) {
ret[ind++] = i;
}
return ret;
}
private byte[] getClientAcceptableVersions() {
byte[] ret = new byte[(maxClientProtocolVersion - minClientProtocolVersion) + 1];
int ind = 0;
for (byte i = minClientProtocolVersion; i <= maxClientProtocolVersion; i++) {
ret[ind++] = i;
}
return ret;
}
private static Set<TranscribableParams>[] transcribableParamsByProtocolVersion;
static {
Set<TranscribableParams>[] map = new Set[TRANSPORT_PROTOCOL_VERSION_MAX_SUPPORTED+1];
map[TRANSPORT_PROTOCOL_VERSION_CLIENT_ONLY_RPC] = Collections.unmodifiableSet(EnumSet.noneOf(TranscribableParams.class));
map[TRANSPORT_PROTOCOL_VERSION_BIDIRECTION_RPC] = Collections.unmodifiableSet(EnumSet.noneOf(TranscribableParams.class));
map[TRANSPORT_PROTOCOL_VERSION_START_TLS] = Collections.unmodifiableSet(EnumSet.of(TranscribableParams.EnumsWrittenAsStrings, TranscribableParams.MajorOnlyPackageNaming));
map[TRANSPORT_PROTOCOL_VERSION_TIME_CONSTRAINTS] = Collections.unmodifiableSet(EnumSet.of(TranscribableParams.EnumsWrittenAsStrings, TranscribableParams.MajorOnlyPackageNaming));
map[TRANSPORT_PROTOCOL_VERSION_COMPOUND_REQUEST_UUID] = Collections.unmodifiableSet(EnumSet.of(TranscribableParams.EnumsWrittenAsStrings, TranscribableParams.MajorOnlyPackageNaming));
transcribableParamsByProtocolVersion = map;
}
public static Set<TranscribableParams> getTranscribableParamSet(IoSession session) {
return getTranscribableParamSet(getProtocolVersion(session));
}
public static Set<TranscribableParams> getTranscribableParamSet(byte protocolVersion) {
return transcribableParamsByProtocolVersion[protocolVersion];
}
private final NioLogger nioLogger;
private boolean isServer;
private volatile boolean isEnabled = false;
private final int interval;
private final int timeout;
private final AtomicLong heartbeatsMissed = new AtomicLong();
private final AtomicLong heartbeatsSent = new AtomicLong();
private final AtomicLong sessionsCreated = new AtomicLong();
private String lastSessionFrom = null;
private final SSLFilter sslFilter;
private final boolean supportsTls;
private final boolean requiresTls;
private final long rpcTimeoutMillis;
public static CougarProtocol getClientInstance(NioLogger nioLogger, int keepAliveInterval, int keepAliveTimeout, SSLFilter sslFilter, boolean supportsTls, boolean requiresTls, long rpcTimeoutMillis) {
return new CougarProtocol(false, nioLogger, keepAliveInterval, keepAliveTimeout, sslFilter, supportsTls, requiresTls, rpcTimeoutMillis);
}
public static CougarProtocol getServerInstance(NioLogger nioLogger, int keepAliveInterval, int keepAliveTimeout, SSLFilter sslFilter, boolean supportsTls, boolean requiresTls) {
return new CougarProtocol(true, nioLogger, keepAliveInterval, keepAliveTimeout, sslFilter, supportsTls, requiresTls, 0);
}
protected CougarProtocol4(boolean server, NioLogger nioLogger, int keepAliveInterval, int keepAliveTimeout, SSLFilter sslFilter, boolean supportsTls, boolean requiresTls, long rpcTimeoutMillis) {
super(server, nioLogger, keepAliveInterval, keepAliveTimeout, sslFilter, supportsTls, requiresTls, rpcTimeoutMillis);
this.isServer = server;
this.nioLogger = nioLogger;
this.interval = keepAliveInterval;
this.timeout = keepAliveTimeout;
this.sslFilter = sslFilter;
this.supportsTls = supportsTls;
this.requiresTls = requiresTls;
this.rpcTimeoutMillis = rpcTimeoutMillis;
export(nioLogger.getJmxControl());
}
public void closeSession(final IoSession ioSession) {
closeSession(ioSession, false);
}
public void closeSession(final IoSession ioSession, boolean blockUntilComplete) {
WriteFuture future = ioSession.write(new DisconnectMessage());
final AtomicReference<CloseFuture> closeFuture = new AtomicReference<CloseFuture>();
final CountDownLatch latch = new CountDownLatch(1);
future.addListener(new IoFutureListener() {
@Override
public void operationComplete(IoFuture future) {
nioLogger.log(NioLogger.LoggingLevel.SESSION, ioSession, "CougarProtocol - Closing session after disconnection");
closeFuture.set(future.getSession().close());
latch.countDown();
}
});
if (blockUntilComplete) {
try {
future.join();
latch.await();
closeFuture.get().join();
}
catch (InterruptedException ie) {
// ignore, this shouldn't happen, and tends only to be used for tests
}
}
}
public void suspendSession(final IoSession ioSession) {
final Byte protocolVersion = (Byte) ioSession.getAttribute(PROTOCOL_VERSION_ATTR_NAME);
if (protocolVersion == null || protocolVersion.equals(TRANSPORT_PROTOCOL_VERSION_CLIENT_ONLY_RPC)) {
return; // We don't need to do this for clients using older version, as they don't understand this message
}
WriteFuture future = ioSession.write(new SuspendMessage());
future.addListener(new IoFutureListener() {
@Override
public void operationComplete(IoFuture future) {
nioLogger.log(NioLogger.LoggingLevel.SESSION, ioSession, "CougarProtocol - Suspended session");
}
});
}
@Override
public void sessionOpened(NextFilter nextFilter, IoSession session) throws Exception {
if (!isServer) {
ClientHandshake clientHandshake = new ClientHandshake();
session.setAttribute(ClientHandshake.HANDSHAKE, clientHandshake);
session.write(new ConnectMessage(getClientAcceptableVersions()));
}
super.sessionOpened(nextFilter, session);
}
@Override
public void sessionIdle(NextFilter nextFilter, IoSession session, IdleStatus status) throws Exception {
try {
if (status == IdleStatus.WRITER_IDLE) {
nioLogger.log(PROTOCOL, session, "CougarProtocolCodecFilter: sending KEEP_ALIVE");
session.write(KEEP_ALIVE);
heartbeatsSent.incrementAndGet();
} else {
nioLogger.log(PROTOCOL, session, "CougarProtocolCodecFilter: KEEP_ALIVE timeout closing session");
session.close();
heartbeatsMissed.incrementAndGet();
}
} finally {
nextFilter.sessionIdle(session, status);
}
}
@Override
public void sessionCreated(NextFilter nextFilter, IoSession session) throws Exception {
session.setIdleTime(IdleStatus.READER_IDLE, timeout);
session.setIdleTime(IdleStatus.WRITER_IDLE, interval);
nextFilter.sessionCreated(session);
nioLogger.log(SESSION, session, "CougarProtocolCodecFilter: Created session at %s from %s", session.getCreationTime(), session.getRemoteAddress());
sessionsCreated.incrementAndGet();
lastSessionFrom = session.getRemoteAddress().toString();
}
public static byte getProtocolVersion(IoSession session) {
Byte b = (Byte) session.getAttribute(PROTOCOL_VERSION_ATTR_NAME);
if (b == null) {
throw new IllegalStateException("Protocol version requested for session before determined");
}
return b;
}
@Override
public void messageReceived(NextFilter nextFilter, IoSession session, Object message) throws Exception {
if (message instanceof ProtocolMessage) {
ProtocolMessage protocolMessage = (ProtocolMessage) message;
switch (protocolMessage.getProtocolMessageType()) {
case CONNECT:
// server side - request to connect from client
if (isEnabled()) {
ConnectMessage connectMessage = (ConnectMessage) protocolMessage;
//As a server, ensure that we support a version the client also supports
byte protocolVersionToUse = TRANSPORT_PROTOCOL_VERSION_UNSUPPORTED;
for (byte testVersion = maxServerProtocolVersion; testVersion >= minServerProtocolVersion; testVersion--) {
if (Arrays.binarySearch(connectMessage.getApplicationVersions(), testVersion) >= 0) {
protocolVersionToUse = testVersion;
break;
}
}
if (protocolVersionToUse >= minServerProtocolVersion) {
// older versions of the protocol don't support TLS, so if we require it, then we have to stop here
if (protocolVersionToUse < TRANSPORT_PROTOCOL_VERSION_START_TLS && requiresTls) {
nioLogger.log(PROTOCOL, session, "CougarProtocol: REJECTing connection request with version %s since we require TLS, which is not supported on this version", protocolVersionToUse);
session.write(new RejectMessage(RejectMessageReason.INCOMPATIBLE_VERSION, getServerAcceptableVersions()));
session.close();
}
else {
nioLogger.log(PROTOCOL, session, "CougarProtocol: ACCEPTing connection request with version %s", protocolVersionToUse);
session.setAttribute(PROTOCOL_VERSION_ATTR_NAME, protocolVersionToUse);
session.setAttribute(IS_SERVER_ATTR_NAME, true);
// this is used for all writes to the session after the initial handshaking
session.setAttribute(RequestResponseManager.SESSION_KEY, new RequestResponseManagerImpl(session, nioLogger, rpcTimeoutMillis));
session.write(new AcceptMessage(protocolVersionToUse));
}
} else {
//we don't speak your language. goodbye
nioLogger.log(PROTOCOL, session, "CougarProtocol: REJECTing connection request with versions %s", getAsString(connectMessage.getApplicationVersions()));
LOG.info("REJECTing connection request from session " + session.getRemoteAddress() + " with versions " + getAsString(connectMessage.getApplicationVersions()));
session.write(new RejectMessage(RejectMessageReason.INCOMPATIBLE_VERSION, getServerAcceptableVersions()));
session.close();
}
} else {
nioLogger.log(PROTOCOL, session, "REJECTing connection request from session %s as service unavailable", session.getReadMessages());
LOG.info("REJECTing connection request from session " + session.getReadMessages() + " as service unavailable");
session.write(new RejectMessage(RejectMessageReason.SERVER_UNAVAILABLE, getServerAcceptableVersions()));
session.close();
}
break;
case ACCEPT:
//Client Side - server has accepted our connection request
AcceptMessage acceptMessage = (AcceptMessage) protocolMessage;
if (acceptMessage.getAcceptedVersion() < minClientProtocolVersion || acceptMessage.getAcceptedVersion() > maxClientProtocolVersion) {
nioLogger.log(PROTOCOL, session, "Protocol version mismatch - client version is %s, server has accepted %s", maxClientProtocolVersion, acceptMessage.getAcceptedVersion());
session.close();
throw new IllegalStateException("Protocol version mismatch - client version is " + maxClientProtocolVersion + ", server has accepted " + acceptMessage.getAcceptedVersion());
}
nioLogger.log(PROTOCOL, session, "CougarProtocol: ACCEPT received for with version %s", acceptMessage.getAcceptedVersion());
session.setAttribute(IS_SERVER_ATTR_NAME, false);
session.setAttribute(PROTOCOL_VERSION_ATTR_NAME, acceptMessage.getAcceptedVersion());
session.setAttribute(RequestResponseManager.SESSION_KEY, new RequestResponseManagerImpl(session, nioLogger, rpcTimeoutMillis));
// if we're running version 3 or later then send our TLS request, otherwise we're done handshaking
if (acceptMessage.getAcceptedVersion() >= TRANSPORT_PROTOCOL_VERSION_START_TLS) {
TLSRequirement requirement;
if (requiresTls) {
requirement = TLSRequirement.REQUIRED;
}
else if (supportsTls) {
requirement = TLSRequirement.SUPPORTED;
}
else {
requirement = TLSRequirement.NONE;
}
nioLogger.log(PROTOCOL, session, "CougarProtocol: Server version supports TLS, sending requirement of %s", requirement);
session.write(new StartTLSRequestMessage(requirement));
}
// if we had to have tls, but the server is running an old version then we need to disconnect
else if (requiresTls) {
nioLogger.log(PROTOCOL, session, "CougarProtocol: Server version doesn't support TLS, sending DISCONNECT");
session.write(new DisconnectMessage());
ClientHandshake handshake = (ClientHandshake) session.getAttribute(ClientHandshake.HANDSHAKE);
if (handshake != null) {
handshake.reject();
}
session.close();
}
else {
ClientHandshake handshake = (ClientHandshake) session.getAttribute(ClientHandshake.HANDSHAKE);
if (handshake != null) {
handshake.accept();
}
}
break;
case REJECT:
//Client Side - server has said foxtrot oscar
RejectMessage rejectMessage = (RejectMessage) protocolMessage;
nioLogger.log(PROTOCOL, session, "CougarProtocol: REJECT received: versions accepted are %s", getAsString(rejectMessage.getAcceptableVersions()));
ClientHandshake handshake2 = (ClientHandshake) session.getAttribute(ClientHandshake.HANDSHAKE);
if (handshake2 != null) {
handshake2.reject();
}
session.close();
break;
case START_TLS_REQUEST:
// server side - client has sent it's tls requirements
StartTLSRequestMessage tlsRequestMessage = (StartTLSRequestMessage) protocolMessage;
nioLogger.log(PROTOCOL, session, "CougarProtocol: START_TLS_REQUEST - requirement is %s", tlsRequestMessage.getRequirement());
TLSResult result;
switch (tlsRequestMessage.getRequirement()) {
case NONE:
if (requiresTls) {
result = TLSResult.FAILED_NEGOTIATION;
}
else {
result = TLSResult.PLAINTEXT;
}
break;
case SUPPORTED:
if (supportsTls) {
result = TLSResult.SSL;
}
else {
result = TLSResult.PLAINTEXT;
}
break;
case REQUIRED:
if (supportsTls) {
result = TLSResult.SSL;
}
else {
result = TLSResult.FAILED_NEGOTIATION;
}
break;
default:
throw new IllegalStateException("Unsupported TLS requirement received " + tlsRequestMessage.getRequirement());
}
StartTLSResponseMessage tlsResponseMessage = new StartTLSResponseMessage(result);
if (result != TLSResult.FAILED_NEGOTIATION) {
session.setAttribute(CougarProtocol.NEGOTIATED_TLS_LEVEL_ATTR_NAME, result);
nioLogger.log(PROTOCOL, session, "CougarProtocol: START_TLS_REQUEST - successfully negotiated %s comms", result);
if (result == TLSResult.SSL) {
// Insert SSLFilter to get ready for handshaking
session.getFilterChain().addFirst("ssl", sslFilter);
// Disable encryption temporarilly.
// This attribute will be removed by SSLFilter
// inside the Session.write() call below.
session.setAttribute(SSLFilter.DISABLE_ENCRYPTION_ONCE, Boolean.TRUE);
}
}
session.write(tlsResponseMessage);
if (result == TLSResult.SSL) {
// Now DISABLE_ENCRYPTION_ONCE attribute is cleared.
assert session.getAttribute(SSLFilter.DISABLE_ENCRYPTION_ONCE) == null;
}
else if (result == TLSResult.FAILED_NEGOTIATION) {
nioLogger.log(PROTOCOL, session, "CougarProtocol: START_TLS_REQUEST - negotiation failed, closing session");
session.close();
}
break;
case START_TLS_RESPONSE:
// client side - server has determined our TLS settings for this connection
StartTLSResponseMessage responseMessage = (StartTLSResponseMessage) protocolMessage;
if (responseMessage.getResult() == TLSResult.FAILED_NEGOTIATION) {
nioLogger.log(PROTOCOL, session, "CougarProtocol: START_TLS_RESPONSE - FAILED_NEGOTIATION received");
ClientHandshake handshake3 = (ClientHandshake) session.getAttribute(ClientHandshake.HANDSHAKE);
if (handshake3 != null) {
handshake3.reject();
}
session.close();
}
else {
session.setAttribute(CougarProtocol.NEGOTIATED_TLS_LEVEL_ATTR_NAME, responseMessage.getResult());
nioLogger.log(PROTOCOL, session, "CougarProtocol: Starting %s comms following successful TLS negotiation", responseMessage.getResult());
if (responseMessage.getResult() == TLSResult.SSL) {
// Insert SSLFilter to get ready for handshaking
session.getFilterChain().addFirst("ssl", sslFilter);
}
// finish handshaking
ClientHandshake handshake = (ClientHandshake) session.getAttribute(ClientHandshake.HANDSHAKE);
if (handshake != null) {
handshake.accept();
}
}
break;
case KEEP_ALIVE:
//Both sides keep alive received, which is ignored
nioLogger.log(PROTOCOL, session, "CougarProtocol: KEEP_ALIVE received");
break;
case DISCONNECT:
//Client Side - server doesn't love us anymore
session.setAttribute(ProtocolMessage.ProtocolMessageType.DISCONNECT.name());
nioLogger.log(PROTOCOL, session, "CougarProtocol: DISCONNECT received");
session.close();
break;
case SUSPEND:
//Client Side - this session is about to be closed
session.setAttribute(ProtocolMessage.ProtocolMessageType.SUSPEND.name());
nioLogger.log(PROTOCOL, session, "CougarProtocol: SUSPEND received");
break;
case MESSAGE_REQUEST:
case MESSAGE_RESPONSE:
case EVENT:
super.messageReceived(nextFilter, session, message);
break;
default:
LOG.error("Unknown message type " + protocolMessage.getProtocolMessageType() + " - Ignoring");
}
}
}
private String getAsString(byte[] versions) {
StringBuilder sb = new StringBuilder("{");
boolean first = true;
for (byte b : versions) {
if (first) {
first = false;
} else {
sb.append(",");
}
sb.append(b);
}
sb.append("}");
return sb.toString();
}
@ManagedAttribute
public void setEnabled(boolean healthy) {
this.isEnabled = healthy;
}
@ManagedAttribute
public boolean isEnabled() {
return this.isEnabled;
}
/**
* Exports this service as an MBean, if the JMXControl is available
*/
@Override
public void export(JMXControl jmxControl) {
if (jmxControl != null) {
jmxControl.registerMBean("CoUGAR.socket.transport:name=wireProtocol", this);
}
}
@ManagedAttribute
public int getInterval() {
return interval;
}
@ManagedAttribute
public int getTimeout() {
return timeout;
}
@ManagedAttribute
public long getHeartbeatsMissed() {
return heartbeatsMissed.get();
}
@ManagedAttribute
public long getHeartbeatsSent() {
return heartbeatsSent.get();
}
@ManagedAttribute
public long getSessionsCreated() {
return sessionsCreated.get();
}
@ManagedAttribute
public String getLastSessionFrom() {
return lastSessionFrom;
}
@ManagedAttribute
public boolean isSupportsTls() {
return supportsTls;
}
@ManagedAttribute
public boolean isRequiresTls() {
return requiresTls;
}
// for testing
SSLFilter getSslFilter() {
return sslFilter;
}
}