package org.sdnplatform.sync.internal.rpc;
import java.nio.ByteBuffer;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.Map.Entry;
import net.floodlightcontroller.core.annotations.LogMessageCategory;
import net.floodlightcontroller.core.annotations.LogMessageDoc;
import net.floodlightcontroller.debugcounter.IDebugCounter;
import org.jboss.netty.channel.Channel;
import org.jboss.netty.channel.ChannelHandlerContext;
import org.jboss.netty.channel.ChannelStateEvent;
import org.jboss.netty.channel.MessageEvent;
import org.sdnplatform.sync.IClosableIterator;
import org.sdnplatform.sync.IStoreClient;
import org.sdnplatform.sync.IVersion;
import org.sdnplatform.sync.Versioned;
import org.sdnplatform.sync.ISyncService.Scope;
import org.sdnplatform.sync.error.AuthException;
import org.sdnplatform.sync.error.ObsoleteVersionException;
import org.sdnplatform.sync.error.SyncException;
import org.sdnplatform.sync.internal.Cursor;
import org.sdnplatform.sync.internal.SyncManager;
import org.sdnplatform.sync.internal.config.AuthScheme;
import org.sdnplatform.sync.internal.config.ClusterConfig;
import org.sdnplatform.sync.internal.config.Node;
import org.sdnplatform.sync.internal.config.SyncStoreCCProvider;
import org.sdnplatform.sync.internal.rpc.RPCService.NodeMessage;
import org.sdnplatform.sync.internal.store.IStorageEngine;
import org.sdnplatform.sync.internal.util.ByteArray;
import org.sdnplatform.sync.internal.util.CryptoUtil;
import org.sdnplatform.sync.internal.version.VectorClock;
import org.sdnplatform.sync.thrift.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Channel handler for the RPC service
* @author readams
*/
@LogMessageCategory("State Synchronization")
public class RPCChannelHandler extends AbstractRPCChannelHandler {
/** Logger for this handler. */
protected static final Logger logger =
LoggerFactory.getLogger(RPCChannelHandler.class);
/** Sync manager supplying stores, cursors, cluster config and local node ID. */
protected SyncManager syncManager;
/** RPC service this handler reports connections and acknowledgements to. */
protected RPCService rpcService;
/**
 * Cluster node on the other end of the channel. Assigned in
 * {@code handleHello} from the cluster configuration; stays {@code null}
 * until then, and is {@code null} if the lookup finds no such node.
 */
protected Node remoteNode;
/** Set to {@code true} by {@code handleHello} when the hello has no node ID. */
protected boolean isClientConnection = false;
/**
 * Creates a handler bound to the given sync manager and RPC service.
 *
 * @param syncManager the sync manager stored in {@link #syncManager}
 * @param rpcService the RPC service stored in {@link #rpcService}
 */
public RPCChannelHandler(SyncManager syncManager,
RPCService rpcService) {
super();
this.syncManager = syncManager;
this.rpcService = rpcService;
}
// ****************************
// IdleStateAwareChannelHandler
// ****************************
/**
 * Adds the channel that was just opened to the RPC service's {@code cg}
 * collection.
 *
 * @param ctx the handler context whose channel is registered
 * @param e the channel state event (unused)
 */
@Override
public void channelOpen(ChannelHandlerContext ctx,
                        ChannelStateEvent e) throws Exception {
    final Channel opened = ctx.getChannel();
    rpcService.cg.add(opened);
}
/**
 * Tells the RPC service that the remote node has disconnected. Does
 * nothing when no remote node was ever associated with this channel.
 *
 * @param ctx the handler context (unused)
 * @param e the channel state event (unused)
 */
@Override
public void channelDisconnected(ChannelHandlerContext ctx,
                                ChannelStateEvent e) throws Exception {
    if (remoteNode == null) {
        // No hello with a recognized node ID was processed on this channel.
        return;
    }
    rpcService.disconnectNode(remoteNode.getNodeId());
}
// ******************************************
// AbstractRPCChannelHandler message handlers
// ******************************************
/**
 * Passes every received message straight to the superclass implementation;
 * this override adds no behavior of its own.
 *
 * @param ctx the handler context
 * @param e the message event
 */
@Override
public void messageReceived(ChannelHandlerContext ctx,
MessageEvent e) throws Exception {
super.messageReceived(ctx, e);
}
/**
 * Processes the hello message received from the peer.
 * <p>
 * A hello without a node ID marks the channel as a client connection and
 * nothing else is done. Otherwise the node ID is looked up in the cluster
 * configuration: an unknown node is logged and the channel is closed; a
 * known node is reported to the RPC service as connected and is sent a
 * {@code FULL_SYNC_REQUEST} message.
 *
 * @param hello the hello message from the peer
 * @param channel the channel the hello arrived on
 */
@Override
@LogMessageDoc(level="ERROR",
        message="[{id}->{id}] Attempted connection from unrecognized " +
                "floodlight node {id}; disconnecting",
        explanation="A unknown node connected. This can happen " +
                "transiently if new nodes join the cluster.",
        recommendation="If the problem persists, verify your cluster " +
                "configuration and that you don't have unauthorized agents " +
                "in your network.")
protected void handleHello(HelloMessage hello, Channel channel) {
    if (!hello.isSetNodeId()) {
        // this is a client connection. Don't set this up as a node
        // connection
        isClientConnection = true;
        return;
    }
    remoteNode = syncManager.getClusterConfig().getNode(hello.getNodeId());
    if (remoteNode == null) {
        logger.error("[{}->{}] Attempted connection from unrecognized " +
                     "floodlight node {}; disconnecting",
                     new Object[]{getLocalNodeIdString(),
                                  getRemoteNodeIdString(),
                                  hello.getNodeId()});
        channel.close();
        return;
    }
    rpcService.nodeConnected(remoteNode.getNodeId(), channel);

    FullSyncRequestMessage srm = new FullSyncRequestMessage();
    AsyncMessageHeader header = new AsyncMessageHeader();
    header.setTransactionId(getTransactionId());
    srm.setHeader(header);
    SyncMessage bsm = new SyncMessage(MessageType.FULL_SYNC_REQUEST);
    // Attach the request body. Without this the message went out with only
    // its type set, and the header carrying the transaction ID was dropped.
    bsm.setFullSyncRequest(srm);
    channel.write(bsm);
    // XXX - TODO - if last connection was longer ago than the tombstone
    // timeout, then we need to do a complete flush and reload of our
    // state. This is complex though since this applies across entire
    // partitions and not just single nodes. We'd need to identify the
    // partition and nuke the smaller half (or lower priority in the case
    // of an even split). Downstream listeners would need to be able to
    // handle a state nuke as well. A simple way to nuke would be to ensure
    // floodlight is restarted in the smaller partition.
}
/**
 * Answers a get request by reading every stored version of the requested
 * key from the named raw store and writing them back in a
 * {@code GET_RESPONSE}. Any exception is reported to the peer as an error
 * message of type {@code GET_REQUEST}.
 *
 * @param request the get request, carrying store name, key and header
 * @param channel the channel to write the response or error to
 */
@Override
protected void handleGetRequest(GetRequestMessage request,
                                Channel channel) {
    final String name = request.getStoreName();
    try {
        final IStorageEngine<ByteArray, byte[]> engine =
                syncManager.getRawStore(name);

        // The reply echoes the transaction ID of the request.
        final GetResponseMessage reply = new GetResponseMessage();
        final AsyncMessageHeader replyHeader = new AsyncMessageHeader();
        replyHeader.setTransactionId(request.getHeader().getTransactionId());
        reply.setHeader(replyHeader);

        final ByteArray lookupKey = new ByteArray(request.getKey());
        final Iterator<Versioned<byte[]>> found =
                engine.get(lookupKey).iterator();
        while (found.hasNext()) {
            reply.addToValues(TProtocolUtil.getTVersionedValue(found.next()));
        }

        final SyncMessage envelope =
                new SyncMessage(MessageType.GET_RESPONSE);
        envelope.setGetResponse(reply);
        channel.write(envelope);
    } catch (Exception failure) {
        channel.write(getError(request.getHeader().getTransactionId(),
                               failure, MessageType.GET_REQUEST));
    }
}
/**
 * Stores a value on behalf of the peer and acknowledges it with a
 * {@code PUT_RESPONSE}.
 * <p>
 * If the request carries a versioned value, that value's version is
 * incremented for the local node. If it carries only a plain value, a new
 * vector clock is built by merging all versions currently stored for the
 * key and incrementing the result for the local node. A request with
 * neither is rejected. Any exception is reported to the peer as an error
 * message of type {@code PUT_REQUEST}.
 *
 * @param request the put request
 * @param channel the channel to write the response or error to
 */
@Override
protected void handlePutRequest(PutRequestMessage request,
                                Channel channel) {
    final String name = request.getStoreName();
    try {
        final IStorageEngine<ByteArray, byte[]> engine =
                syncManager.getRawStore(name);
        final ByteArray targetKey = new ByteArray(request.getKey());

        final Versioned<byte[]> toStore;
        if (request.isSetVersionedValue()) {
            toStore = TProtocolUtil.
                    getVersionedValued(request.getVersionedValue());
            toStore.increment(syncManager.getLocalNodeId(),
                              System.currentTimeMillis());
        } else if (request.isSetValue()) {
            final byte[] payload = request.getValue();
            // Fold every known version of the key into one clock.
            VectorClock merged = new VectorClock();
            for (IVersion existing : engine.getVersions(targetKey)) {
                merged = merged.merge((VectorClock) existing);
            }
            merged = merged.incremented(syncManager.getLocalNodeId(),
                                        System.currentTimeMillis());
            toStore = Versioned.value(payload, merged);
        } else {
            throw new SyncException("No value specified for put");
        }
        engine.put(targetKey, toStore);

        final PutResponseMessage reply = new PutResponseMessage();
        final AsyncMessageHeader replyHeader = new AsyncMessageHeader();
        replyHeader.setTransactionId(request.getHeader().getTransactionId());
        reply.setHeader(replyHeader);

        final SyncMessage envelope =
                new SyncMessage(MessageType.PUT_RESPONSE);
        envelope.setPutResponse(reply);
        channel.write(envelope);
    } catch (Exception failure) {
        channel.write(getError(request.getHeader().getTransactionId(),
                               failure, MessageType.PUT_REQUEST));
    }
}
/**
 * Deletes a key on behalf of the peer by storing a versioned {@code null}
 * value for it, then acknowledges with a {@code DELETE_RESPONSE}.
 * <p>
 * The version is taken from the request when present; otherwise it is the
 * merge of all versions currently stored for the key. In both cases the
 * clock is incremented for the local node before the write. Any exception
 * is reported to the peer as an error message of type
 * {@code DELETE_REQUEST}.
 *
 * @param request the delete request
 * @param channel the channel to write the response or error to
 */
@Override
protected void handleDeleteRequest(DeleteRequestMessage request,
                                   Channel channel) {
    try {
        String storeName = request.getStoreName();
        IStorageEngine<ByteArray, byte[]> store =
                syncManager.getRawStore(storeName);
        ByteArray key = new ByteArray(request.getKey());
        VectorClock newclock;
        if (request.isSetVersion()) {
            newclock = TProtocolUtil.getVersion(request.getVersion());
        } else {
            newclock = new VectorClock();
            List<IVersion> versions = store.getVersions(key);
            for (IVersion v : versions) {
                newclock = newclock.merge((VectorClock)v);
            }
        }
        // Use this handler's own sync manager for the local node ID, as
        // handlePutRequest and getLocalNodeId do, rather than reaching
        // through rpcService.
        newclock =
                newclock.incremented(syncManager.getLocalNodeId(),
                                     System.currentTimeMillis());
        // A null value with a newer clock marks the key as deleted.
        Versioned<byte[]> value = Versioned.value(null, newclock);
        store.put(key, value);

        DeleteResponseMessage m = new DeleteResponseMessage();
        AsyncMessageHeader header = new AsyncMessageHeader();
        header.setTransactionId(request.getHeader().getTransactionId());
        m.setHeader(header);
        SyncMessage bsm =
                new SyncMessage(MessageType.DELETE_RESPONSE);
        bsm.setDeleteResponse(m);
        channel.write(bsm);
    } catch (Exception e) {
        channel.write(getError(request.getHeader().getTransactionId(), e,
                               MessageType.DELETE_REQUEST));
    }
}
/**
 * Applies values pushed by the peer to the local stores and replies with
 * a {@code SYNC_VALUE_RESPONSE} carrying the number of keyed values
 * received.
 * <p>
 * When the message is marked as a response to an earlier message, a
 * {@code SYNC_REQUEST} acknowledgement is first recorded with the RPC
 * service. Any exception while applying the values is reported to the
 * peer as an error message of type {@code SYNC_VALUE}.
 *
 * @param request the sync value message
 * @param channel the channel to write the response or error to
 */
@Override
protected void handleSyncValue(SyncValueMessage request,
                               Channel channel) {
    if (request.isSetResponseTo()) {
        rpcService.messageAcked(MessageType.SYNC_REQUEST,
                                getRemoteNodeId());
    }

    try {
        if (logger.isTraceEnabled()) {
            logger.trace("[{}->{}] Got syncvalue {}",
                         new Object[]{getLocalNodeIdString(),
                                      getRemoteNodeIdString(),
                                      request});
        }

        final Scope targetScope =
                TProtocolUtil.getScope(request.getStore().getScope());
        for (KeyedValues keyed : request.getValues()) {
            // Present the thrift values as Versioned<byte[]> values.
            final Iterable<Versioned<byte[]>> converted =
                    new TVersionedValueIterable(keyed.getValues());
            syncManager.writeSyncValue(request.getStore().getStoreName(),
                                       targetScope,
                                       request.getStore().isPersist(),
                                       keyed.getKey(), converted);
        }

        final AsyncMessageHeader replyHeader = new AsyncMessageHeader();
        replyHeader.setTransactionId(request.getHeader().getTransactionId());
        final SyncValueResponseMessage reply = new SyncValueResponseMessage();
        reply.setCount(request.getValuesSize());
        reply.setHeader(replyHeader);

        final SyncMessage envelope =
                new SyncMessage(MessageType.SYNC_VALUE_RESPONSE);
        envelope.setSyncValueResponse(reply);

        updateCounter(SyncManager.counterReceivedValues,
                      request.getValuesSize());
        channel.write(envelope);
    } catch (Exception failure) {
        channel.write(getError(request.getHeader().getTransactionId(),
                               failure, MessageType.SYNC_VALUE));
    }
}
/**
 * Records with the RPC service that the remote node acknowledged a
 * {@code SYNC_VALUE} message. The response content itself is not read.
 *
 * @param message the sync value response (unused)
 * @param channel the channel the response arrived on (unused)
 */
@Override
protected void handleSyncValueResponse(SyncValueResponseMessage message,
Channel channel) {
rpcService.messageAcked(MessageType.SYNC_VALUE, getRemoteNodeId());
}
/**
 * Answers a sync offer with a {@code SYNC_REQUEST} listing the offered
 * keys that the sync manager wants to receive.
 * <p>
 * Each offered key is passed with its versions to
 * {@code syncManager.handleSyncOffer}; keys for which it returns
 * {@code true} are added to the request. Any exception is reported to the
 * peer as an error message of type {@code SYNC_OFFER}.
 *
 * @param request the sync offer from the peer
 * @param channel the channel to write the request or error to
 */
@Override
protected void handleSyncOffer(SyncOfferMessage request,
                               Channel channel) {
    try {
        final String offeredStore = request.getStore().getStoreName();

        final SyncRequestMessage wanted = new SyncRequestMessage();
        final AsyncMessageHeader replyHeader = new AsyncMessageHeader();
        replyHeader.setTransactionId(request.getHeader().getTransactionId());
        wanted.setHeader(replyHeader);
        wanted.setStore(request.getStore());

        for (KeyedVersions offered : request.getVersions()) {
            final Iterable<VectorClock> clocks =
                    new TVersionIterable(offered.getVersions());
            if (syncManager.handleSyncOffer(offeredStore,
                                            offered.getKey(), clocks)) {
                wanted.addToKeys(offered.bufferForKey());
            }
        }

        final SyncMessage envelope =
                new SyncMessage(MessageType.SYNC_REQUEST);
        envelope.setSyncRequest(wanted);
        if (logger.isTraceEnabled()) {
            logger.trace("[{}->{}] Sending SyncRequest with {} elements",
                         new Object[]{getLocalNodeIdString(),
                                      getRemoteNodeIdString(),
                                      wanted.getKeysSize()});
        }
        channel.write(envelope);
    } catch (Exception failure) {
        channel.write(getError(request.getHeader().getTransactionId(),
                               failure, MessageType.SYNC_OFFER));
    }
}
/**
 * Serves a sync request by looking up each requested key in the named raw
 * store and queueing a {@code SyncValueMessage} with the values found.
 * <p>
 * A {@code SYNC_OFFER} acknowledgement is always recorded first. Requests
 * without keys are otherwise ignored. Keys with no stored values are
 * skipped, and nothing is queued when no values were collected. Any
 * exception is reported to the peer as an error message of type
 * {@code SYNC_REQUEST}.
 *
 * @param request the sync request from the peer
 * @param channel the channel to write an error to
 */
@Override
protected void handleSyncRequest(SyncRequestMessage request,
                                 Channel channel) {
    rpcService.messageAcked(MessageType.SYNC_OFFER, getRemoteNodeId());
    if (!request.isSetKeys()) return;

    String storeName = request.getStore().getStoreName();
    try {
        IStorageEngine<ByteArray, byte[]> store =
                syncManager.getRawStore(storeName);

        SyncMessage bsm =
                TProtocolUtil.getTSyncValueMessage(request.getStore());
        SyncValueMessage svm = bsm.getSyncValue();
        svm.setResponseTo(request.getHeader().getTransactionId());
        svm.getHeader().setTransactionId(rpcService.getTransactionId());

        for (ByteBuffer key : request.getKeys()) {
            // Copy only the readable bytes. ByteBuffer.array() returns the
            // whole backing array regardless of position, limit and
            // arrayOffset, and throws for read-only or direct buffers.
            // duplicate() leaves the request's own buffer position alone.
            ByteBuffer readable = key.duplicate();
            byte[] keyBytes = new byte[readable.remaining()];
            readable.get(keyBytes);
            ByteArray keyArray = new ByteArray(keyBytes);

            List<Versioned<byte[]>> values =
                    store.get(keyArray);
            if (values == null || values.size() == 0) continue;
            KeyedValues kv =
                    TProtocolUtil.getTKeyedValues(keyArray, values);
            svm.addToValues(kv);
        }

        if (svm.isSetValues()) {
            updateCounter(SyncManager.counterSentValues,
                          svm.getValuesSize());
            // Sent through the RPC service's sync queue rather than
            // written directly to the channel.
            rpcService.syncQueue.add(new NodeMessage(getRemoteNodeId(),
                                                     bsm));
        }
    } catch (Exception e) {
        channel.write(getError(request.getHeader().getTransactionId(), e,
                               MessageType.SYNC_REQUEST));
    }
}
/**
 * Responds to a full sync request by starting anti-entropy via
 * {@link #startAntientropy()}. The request content is not read.
 *
 * @param request the full sync request (unused)
 * @param channel the channel the request arrived on (unused)
 */
@Override
protected void handleFullSyncRequest(FullSyncRequestMessage request,
Channel channel) {
startAntientropy();
}
/**
 * Serves a cursor request and replies with a {@code CURSOR_RESPONSE}.
 * <p>
 * An existing cursor is looked up when the request names a cursor ID;
 * otherwise a new cursor is opened on the requested store. A close
 * request closes the cursor; any other request returns up to 50 further
 * entries from it. Any exception, including an unknown cursor, is
 * reported to the peer as an error message of type {@code CURSOR_REQUEST}.
 *
 * @param request the cursor request from the peer
 * @param channel the channel to write the response or error to
 */
@Override
protected void handleCursorRequest(CursorRequestMessage request,
                                   Channel channel) {
    // Upper bound on entries returned in a single response.
    final int maxPerResponse = 50;
    try {
        final Cursor cursor = request.isSetCursorId()
                ? syncManager.getCursor(request.getCursorId())
                : syncManager.newCursor(request.getStoreName());
        if (cursor == null) {
            throw new SyncException("Unrecognized cursor");
        }

        final CursorResponseMessage reply = new CursorResponseMessage();
        final AsyncMessageHeader replyHeader = new AsyncMessageHeader();
        replyHeader.setTransactionId(request.getHeader().getTransactionId());
        reply.setHeader(replyHeader);
        reply.setCursorId(cursor.getCursorId());

        if (request.isClose()) {
            syncManager.closeCursor(cursor);
        } else {
            for (int sent = 0;
                 sent < maxPerResponse && cursor.hasNext();
                 sent++) {
                final Entry<ByteArray, List<Versioned<byte[]>>> item =
                        cursor.next();
                reply.addToValues(TProtocolUtil.getTKeyedValues(
                        item.getKey(), item.getValue()));
            }
        }

        final SyncMessage envelope =
                new SyncMessage(MessageType.CURSOR_RESPONSE);
        envelope.setCursorResponse(reply);
        channel.write(envelope);
    } catch (Exception failure) {
        channel.write(getError(request.getHeader().getTransactionId(),
                               failure, MessageType.CURSOR_REQUEST));
    }
}
/**
 * Registers the store described in the request with the sync manager,
 * as a persistent store when the request's persist flag is set and as a
 * regular store otherwise, then replies with a {@code REGISTER_RESPONSE}.
 * Any exception is reported to the peer as an error message of type
 * {@code REGISTER_REQUEST}.
 *
 * @param request the register request from the peer
 * @param channel the channel to write the response or error to
 */
@Override
protected void handleRegisterRequest(RegisterRequestMessage request,
                                     Channel channel) {
    try {
        final Scope requestedScope =
                TProtocolUtil.getScope(request.store.getScope());
        if (!request.store.isPersist()) {
            syncManager.registerStore(request.store.storeName,
                                      requestedScope);
        } else {
            syncManager.registerPersistentStore(request.store.storeName,
                                                requestedScope);
        }

        final AsyncMessageHeader replyHeader = new AsyncMessageHeader();
        replyHeader.setTransactionId(request.getHeader().getTransactionId());
        final RegisterResponseMessage reply = new RegisterResponseMessage();
        reply.setHeader(replyHeader);

        final SyncMessage envelope =
                new SyncMessage(MessageType.REGISTER_RESPONSE);
        envelope.setRegisterResponse(reply);
        channel.write(envelope);
    } catch (Exception failure) {
        channel.write(getError(request.getHeader().getTransactionId(),
                               failure, MessageType.REGISTER_REQUEST));
    }
}
/**
 * Handles a request from a node to join the cluster, or to update the
 * address of a node that is already a member.
 * <p>
 * A node without a node ID is given a random unused one, which is also
 * returned in the response. A node without a domain ID is placed in a
 * domain equal to its node ID. The node is then written to the system
 * node store, retrying on version conflicts, and the full contents of
 * that store are returned in a {@code CLUSTER_JOIN_RESPONSE}. Any other
 * exception is reported to the peer as an error message of type
 * {@code CLUSTER_JOIN_REQUEST}.
 *
 * @param request the cluster join request from the peer
 * @param channel the channel to write the response or error to
 */
@Override
protected void handleClusterJoinRequest(ClusterJoinRequestMessage request,
                                        Channel channel) {
    try {
        // We can get this message in two circumstances. Either this is
        // a totally new node, or this is an existing node that is changing
        // its port or IP address. We can tell the difference because the
        // node ID and domain ID will already be set for an existing node
        final ClusterJoinResponseMessage reply =
                new ClusterJoinResponseMessage();
        final AsyncMessageHeader replyHeader = new AsyncMessageHeader();
        replyHeader.setTransactionId(request.getHeader().getTransactionId());
        reply.setHeader(replyHeader);

        final org.sdnplatform.sync.thrift.Node joining = request.getNode();
        if (!joining.isSetNodeId()) {
            // allocate a random node ID that's not currently in use
            // Note that there is an obvious possible race here if multiple
            // nodes join quickly or using different seeds. In this case,
            // if you get unlucky you could have the same random node ID
            // and then bad things would start to happen. We're essentially
            // assuming that node joins are happening one at a time by a
            // human; the randomness is a lame attempt to mitigate this race
            final Random idSource = new Random();
            short candidateId;
            final ClusterConfig config = syncManager.getClusterConfig();
            do {
                candidateId = (short) idSource.nextInt(Short.MAX_VALUE);
            } while (config.getNode(candidateId) != null);
            joining.setNodeId(candidateId);
            reply.setNewNodeId(candidateId);
        }
        if (!joining.isSetDomainId()) {
            // for now put the node into its own domain. Once it joins
            // the cluster, it can easily change its domain by writing a
            // new domain ID into the system node store
            joining.setDomainId(joining.getNodeId());
        }

        final IStoreClient<Short, Node> nodeClient =
                syncManager.getStoreClient(SyncStoreCCProvider.
                                           SYSTEM_NODE_STORE,
                                           Short.class, Node.class);
        boolean stored = false;
        while (!stored) {
            try {
                final Versioned<Node> current =
                        nodeClient.get(joining.getNodeId());
                current.setValue(new Node(joining.getHostname(),
                                          joining.getPort(),
                                          joining.getNodeId(),
                                          joining.getDomainId()));
                nodeClient.put(joining.getNodeId(), current);
                stored = true;
            } catch (ObsoleteVersionException retry) {
                // Version conflict on the write: loop round and read the
                // current version again before retrying the put.
            }
        }

        // Copy every entry of the system node store into the response.
        final IStorageEngine<ByteArray, byte[]> nodeStore =
                syncManager.getRawStore(SyncStoreCCProvider.
                                        SYSTEM_NODE_STORE);
        final IClosableIterator<Entry<ByteArray,
                List<Versioned<byte[]>>>> contents = nodeStore.entries();
        try {
            while (contents.hasNext()) {
                final Entry<ByteArray, List<Versioned<byte[]>>> item =
                        contents.next();
                reply.addToNodeStore(
                        TProtocolUtil.getTKeyedValues(item.getKey(),
                                                      item.getValue()));
            }
        } finally {
            contents.close();
        }

        final SyncMessage envelope =
                new SyncMessage(MessageType.CLUSTER_JOIN_RESPONSE);
        envelope.setClusterJoinResponse(reply);
        channel.write(envelope);
    } catch (Exception failure) {
        channel.write(getError(request.getHeader().getTransactionId(),
                               failure, MessageType.CLUSTER_JOIN_REQUEST));
    }
}
/**
 * Handles an error message from the peer: records an acknowledgement for
 * the message type named in the error, bumps the remote-error counter by
 * one, and then runs the superclass error handling.
 *
 * @param error the error message from the peer
 * @param channel the channel the error arrived on
 */
@Override
protected void handleError(ErrorMessage error, Channel channel) {
rpcService.messageAcked(error.getType(), getRemoteNodeId());
updateCounter(SyncManager.counterErrorRemote, 1);
super.handleError(error, channel);
}
// *************************
// AbstractRPCChannelHandler
// *************************
/**
 * Returns the local node ID as reported by the sync manager.
 *
 * @return the local node ID
 */
@Override
protected Short getLocalNodeId() {
return syncManager.getLocalNodeId();
}
/**
 * Returns the node ID of the remote node on this channel.
 *
 * @return the remote node ID, or {@code null} when no remote node is set
 */
@Override
protected Short getRemoteNodeId() {
    if (remoteNode == null) {
        return null;
    }
    return remoteNode.getNodeId();
}
/**
 * Returns the local node ID rendered as a string, for use in log output.
 *
 * @return the string form of {@link #getLocalNodeId()}; {@code "null"}
 *         when that ID is {@code null}
 */
@Override
protected String getLocalNodeIdString() {
    return String.valueOf(getLocalNodeId());
}
/**
 * Returns the remote node ID rendered as a string, for use in log output.
 *
 * @return the string form of {@link #getRemoteNodeId()}; {@code "null"}
 *         when no remote node is set
 */
@Override
protected String getRemoteNodeIdString() {
    return String.valueOf(getRemoteNodeId());
}
/**
 * Returns a transaction ID obtained from the RPC service.
 *
 * @return the transaction ID returned by {@code rpcService.getTransactionId()}
 */
@Override
protected int getTransactionId() {
return rpcService.getTransactionId();
}
/**
 * Returns the authentication scheme from the current cluster configuration.
 *
 * @return the configured authentication scheme
 */
@Override
protected AuthScheme getAuthScheme() {
return syncManager.getClusterConfig().getAuthScheme();
}
/**
 * Reads the shared secret from the key store named in the cluster
 * configuration, using the configured key store password.
 *
 * @return the shared secret bytes
 * @throws AuthException if the secret cannot be read; the underlying
 *         exception is attached as the cause
 */
@Override
protected byte[] getSharedSecret() throws AuthException {
    final String keyStorePath =
            syncManager.getClusterConfig().getKeyStorePath();
    final String keyStorePassword =
            syncManager.getClusterConfig().getKeyStorePassword();
    try {
        return CryptoUtil.getSharedSecret(keyStorePath, keyStorePassword);
    } catch (Exception cause) {
        throw new AuthException("Could not read challenge/response " +
                "shared secret from key store " + keyStorePath, cause);
    }
}
/**
 * Builds an error message via the superclass, after bumping the
 * processing-error counter by one.
 *
 * @param transactionId the transaction ID passed through to the superclass
 * @param error the exception passed through to the superclass
 * @param type the message type passed through to the superclass
 * @return the error message produced by the superclass
 */
@Override
protected SyncMessage getError(int transactionId, Exception error,
MessageType type) {
updateCounter(SyncManager.counterErrorProcessing, 1);
return super.getError(transactionId, error, type);
}
// *****************
// Utility functions
// *****************
/**
 * Adds {@code incr} to the given debug counter by calling its
 * {@code updateCounterWithFlush} method.
 *
 * @param counter the counter to update
 * @param incr the amount to add
 */
protected void updateCounter(IDebugCounter counter, int incr) {
counter.updateCounterWithFlush(incr);
}
/**
 * Submits a task to the sync manager's scheduled executor that runs
 * anti-entropy against the remote node of this channel. The
 * {@code remoteNode} field is read when the task runs, not when it is
 * submitted.
 */
protected void startAntientropy() {
    // Run antientropy in a background task so we don't use up an I/O
    // thread. Note that this task will result in lots of traffic
    // that will use I/O threads but each of those will be in manageable
    // chunks
    syncManager.getThreadPool().getScheduledExecutor().execute(
            new Runnable() {
                @Override
                public void run() {
                    syncManager.antientropy(remoteNode);
                }
            });
}
/**
 * Presents an iterable of thrift vector clocks as an iterable of
 * {@link VectorClock} objects. Each element is converted with
 * {@code TProtocolUtil.getVersion} at the moment it is requested.
 */
protected static class TVersionIterable
        implements Iterable<VectorClock> {
    /** The wrapped thrift vector clocks. */
    final Iterable<org.sdnplatform.sync.thrift.VectorClock> tcvi;

    /**
     * @param tcvi the thrift vector clocks to wrap
     */
    public TVersionIterable(Iterable<org.sdnplatform.sync.thrift.VectorClock> tcvi) {
        this.tcvi = tcvi;
    }

    @Override
    public Iterator<VectorClock> iterator() {
        return new ConvertingIterator(tcvi.iterator());
    }

    /** Iterator that converts each thrift clock as it is handed out. */
    private static final class ConvertingIterator
            implements Iterator<VectorClock> {
        private final Iterator<org.sdnplatform.sync.thrift.VectorClock> source;

        ConvertingIterator(Iterator<org.sdnplatform.sync.thrift.VectorClock> source) {
            this.source = source;
        }

        @Override
        public boolean hasNext() {
            return source.hasNext();
        }

        @Override
        public VectorClock next() {
            return TProtocolUtil.getVersion(source.next());
        }

        @Override
        public void remove() {
            // Removal is passed through to the wrapped iterator.
            source.remove();
        }
    }
}
}