package org.sdnplatform.sync.internal.rpc; import java.nio.ByteBuffer; import java.util.Iterator; import java.util.List; import java.util.Random; import java.util.Map.Entry; import net.floodlightcontroller.core.annotations.LogMessageCategory; import net.floodlightcontroller.core.annotations.LogMessageDoc; import net.floodlightcontroller.debugcounter.IDebugCounter; import org.jboss.netty.channel.Channel; import org.jboss.netty.channel.ChannelHandlerContext; import org.jboss.netty.channel.ChannelStateEvent; import org.jboss.netty.channel.MessageEvent; import org.sdnplatform.sync.IClosableIterator; import org.sdnplatform.sync.IStoreClient; import org.sdnplatform.sync.IVersion; import org.sdnplatform.sync.Versioned; import org.sdnplatform.sync.ISyncService.Scope; import org.sdnplatform.sync.error.AuthException; import org.sdnplatform.sync.error.ObsoleteVersionException; import org.sdnplatform.sync.error.SyncException; import org.sdnplatform.sync.internal.Cursor; import org.sdnplatform.sync.internal.SyncManager; import org.sdnplatform.sync.internal.config.AuthScheme; import org.sdnplatform.sync.internal.config.ClusterConfig; import org.sdnplatform.sync.internal.config.Node; import org.sdnplatform.sync.internal.config.SyncStoreCCProvider; import org.sdnplatform.sync.internal.rpc.RPCService.NodeMessage; import org.sdnplatform.sync.internal.store.IStorageEngine; import org.sdnplatform.sync.internal.util.ByteArray; import org.sdnplatform.sync.internal.util.CryptoUtil; import org.sdnplatform.sync.internal.version.VectorClock; import org.sdnplatform.sync.thrift.*; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * Channel handler for the RPC service * @author readams */ @LogMessageCategory("State Synchronization") public class RPCChannelHandler extends AbstractRPCChannelHandler { protected static final Logger logger = LoggerFactory.getLogger(RPCChannelHandler.class); protected SyncManager syncManager; protected RPCService rpcService; protected Node remoteNode; protected boolean isClientConnection = false; public RPCChannelHandler(SyncManager syncManager, RPCService rpcService) { super(); this.syncManager = syncManager; this.rpcService = rpcService; } // **************************** // IdleStateAwareChannelHandler // **************************** @Override public void channelOpen(ChannelHandlerContext ctx, ChannelStateEvent e) throws Exception { rpcService.cg.add(ctx.getChannel()); } @Override public void channelDisconnected(ChannelHandlerContext ctx, ChannelStateEvent e) throws Exception { if (remoteNode != null) { rpcService.disconnectNode(remoteNode.getNodeId()); } } // ****************************************** // AbstractRPCChannelHandler message handlers // ****************************************** @Override public void messageReceived(ChannelHandlerContext ctx, MessageEvent e) throws Exception { super.messageReceived(ctx, e); } @Override @LogMessageDoc(level="ERROR", message="[{id}->{id}] Attempted connection from unrecognized " + "floodlight node {id}; disconnecting", explanation="A unknown node connected. This can happen " + "transiently if new nodes join the cluster.", recommendation="If the problem persists, verify your cluster" + "configuration and that you don't have unauthorized agents " + "in your network.") protected void handleHello(HelloMessage hello, Channel channel) { if (!hello.isSetNodeId()) { // this is a client connection. Don't set this up as a node // connection isClientConnection = true; return; } remoteNode = syncManager.getClusterConfig().getNode(hello.getNodeId()); if (remoteNode == null) { logger.error("[{}->{}] Attempted connection from unrecognized " + "floodlight node {}; disconnecting", new Object[]{getLocalNodeIdString(), getRemoteNodeIdString(), hello.getNodeId()}); channel.close(); return; } rpcService.nodeConnected(remoteNode.getNodeId(), channel); FullSyncRequestMessage srm = new FullSyncRequestMessage(); AsyncMessageHeader header = new AsyncMessageHeader(); header.setTransactionId(getTransactionId()); srm.setHeader(header); SyncMessage bsm = new SyncMessage(MessageType.FULL_SYNC_REQUEST); channel.write(bsm); // XXX - TODO - if last connection was longer ago than the tombstone // timeout, then we need to do a complete flush and reload of our // state. This is complex though since this applies across entire // partitions and not just single nodes. We'd need to identify the // partition and nuke the smaller half (or lower priority in the case // of an even split). Downstream listeners would need to be able to // handle a state nuke as well. A simple way to nuke would be to ensure // floodlight is restarted in the smaller partition. } @Override protected void handleGetRequest(GetRequestMessage request, Channel channel) { String storeName = request.getStoreName(); try { IStorageEngine<ByteArray, byte[]> store = syncManager.getRawStore(storeName); GetResponseMessage m = new GetResponseMessage(); AsyncMessageHeader header = new AsyncMessageHeader(); header.setTransactionId(request.getHeader().getTransactionId()); m.setHeader(header); List<Versioned<byte[]>> values = store.get(new ByteArray(request.getKey())); for (Versioned<byte[]> value : values) { m.addToValues(TProtocolUtil.getTVersionedValue(value)); } SyncMessage bsm = new SyncMessage(MessageType.GET_RESPONSE); bsm.setGetResponse(m); channel.write(bsm); } catch (Exception e) { channel.write(getError(request.getHeader().getTransactionId(), e, MessageType.GET_REQUEST)); } } @Override protected void handlePutRequest(PutRequestMessage request, Channel channel) { String storeName = request.getStoreName(); try { IStorageEngine<ByteArray, byte[]> store = syncManager.getRawStore(storeName); ByteArray key = new ByteArray(request.getKey()); Versioned<byte[]> value = null; if (request.isSetVersionedValue()) { value = TProtocolUtil. getVersionedValued(request.getVersionedValue()); value.increment(syncManager.getLocalNodeId(), System.currentTimeMillis()); } else if (request.isSetValue()) { byte[] rvalue = request.getValue(); List<IVersion> versions = store.getVersions(key); VectorClock newclock = new VectorClock(); for (IVersion v : versions) { newclock = newclock.merge((VectorClock)v); } newclock = newclock.incremented(syncManager.getLocalNodeId(), System.currentTimeMillis()); value = Versioned.value(rvalue, newclock); } else { throw new SyncException("No value specified for put"); } store.put(key, value); PutResponseMessage m = new PutResponseMessage(); AsyncMessageHeader header = new AsyncMessageHeader(); header.setTransactionId(request.getHeader().getTransactionId()); m.setHeader(header); SyncMessage bsm = new SyncMessage(MessageType.PUT_RESPONSE); bsm.setPutResponse(m); channel.write(bsm); } catch (Exception e) { channel.write(getError(request.getHeader().getTransactionId(), e, MessageType.PUT_REQUEST)); } } @Override protected void handleDeleteRequest(DeleteRequestMessage request, Channel channel) { try { String storeName = request.getStoreName(); IStorageEngine<ByteArray, byte[]> store = syncManager.getRawStore(storeName); ByteArray key = new ByteArray(request.getKey()); VectorClock newclock; if (request.isSetVersion()) { newclock = TProtocolUtil.getVersion(request.getVersion()); } else { newclock = new VectorClock(); List<IVersion> versions = store.getVersions(key); for (IVersion v : versions) { newclock = newclock.merge((VectorClock)v); } } newclock = newclock.incremented(rpcService.syncManager.getLocalNodeId(), System.currentTimeMillis()); Versioned<byte[]> value = Versioned.value(null, newclock); store.put(key, value); DeleteResponseMessage m = new DeleteResponseMessage(); AsyncMessageHeader header = new AsyncMessageHeader(); header.setTransactionId(request.getHeader().getTransactionId()); m.setHeader(header); SyncMessage bsm = new SyncMessage(MessageType.DELETE_RESPONSE); bsm.setDeleteResponse(m); channel.write(bsm); } catch (Exception e) { channel.write(getError(request.getHeader().getTransactionId(), e, MessageType.DELETE_REQUEST)); } } @Override protected void handleSyncValue(SyncValueMessage request, Channel channel) { if (request.isSetResponseTo()) rpcService.messageAcked(MessageType.SYNC_REQUEST, getRemoteNodeId()); try { if (logger.isTraceEnabled()) { logger.trace("[{}->{}] Got syncvalue {}", new Object[]{getLocalNodeIdString(), getRemoteNodeIdString(), request}); } Scope scope = TProtocolUtil.getScope(request.getStore().getScope()); for (KeyedValues kv : request.getValues()) { Iterable<VersionedValue> tvvi = kv.getValues(); Iterable<Versioned<byte[]>> vs = new TVersionedValueIterable(tvvi); syncManager.writeSyncValue(request.getStore().getStoreName(), scope, request.getStore().isPersist(), kv.getKey(), vs); } SyncValueResponseMessage m = new SyncValueResponseMessage(); m.setCount(request.getValuesSize()); AsyncMessageHeader header = new AsyncMessageHeader(); header.setTransactionId(request.getHeader().getTransactionId()); m.setHeader(header); SyncMessage bsm = new SyncMessage(MessageType.SYNC_VALUE_RESPONSE); bsm.setSyncValueResponse(m); updateCounter(SyncManager.counterReceivedValues, request.getValuesSize()); channel.write(bsm); } catch (Exception e) { channel.write(getError(request.getHeader().getTransactionId(), e, MessageType.SYNC_VALUE)); } } @Override protected void handleSyncValueResponse(SyncValueResponseMessage message, Channel channel) { rpcService.messageAcked(MessageType.SYNC_VALUE, getRemoteNodeId()); } @Override protected void handleSyncOffer(SyncOfferMessage request, Channel channel) { try { String storeName = request.getStore().getStoreName(); SyncRequestMessage srm = new SyncRequestMessage(); AsyncMessageHeader header = new AsyncMessageHeader(); header.setTransactionId(request.getHeader().getTransactionId()); srm.setHeader(header); srm.setStore(request.getStore()); for (KeyedVersions kv : request.getVersions()) { Iterable<org.sdnplatform.sync.thrift.VectorClock> tvci = kv.getVersions(); Iterable<VectorClock> vci = new TVersionIterable(tvci); boolean wantKey = syncManager.handleSyncOffer(storeName, kv.getKey(), vci); if (wantKey) srm.addToKeys(kv.bufferForKey()); } SyncMessage bsm = new SyncMessage(MessageType.SYNC_REQUEST); bsm.setSyncRequest(srm); if (logger.isTraceEnabled()) { logger.trace("[{}->{}] Sending SyncRequest with {} elements", new Object[]{getLocalNodeIdString(), getRemoteNodeIdString(), srm.getKeysSize()}); } channel.write(bsm); } catch (Exception e) { channel.write(getError(request.getHeader().getTransactionId(), e, MessageType.SYNC_OFFER)); } } @Override protected void handleSyncRequest(SyncRequestMessage request, Channel channel) { rpcService.messageAcked(MessageType.SYNC_OFFER, getRemoteNodeId()); if (!request.isSetKeys()) return; String storeName = request.getStore().getStoreName(); try { IStorageEngine<ByteArray, byte[]> store = syncManager.getRawStore(storeName); SyncMessage bsm = TProtocolUtil.getTSyncValueMessage(request.getStore()); SyncValueMessage svm = bsm.getSyncValue(); svm.setResponseTo(request.getHeader().getTransactionId()); svm.getHeader().setTransactionId(rpcService.getTransactionId()); for (ByteBuffer key : request.getKeys()) { ByteArray keyArray = new ByteArray(key.array()); List<Versioned<byte[]>> values = store.get(keyArray); if (values == null || values.size() == 0) continue; KeyedValues kv = TProtocolUtil.getTKeyedValues(keyArray, values); svm.addToValues(kv); } if (svm.isSetValues()) { updateCounter(SyncManager.counterSentValues, svm.getValuesSize()); rpcService.syncQueue.add(new NodeMessage(getRemoteNodeId(), bsm)); } } catch (Exception e) { channel.write(getError(request.getHeader().getTransactionId(), e, MessageType.SYNC_REQUEST)); } } @Override protected void handleFullSyncRequest(FullSyncRequestMessage request, Channel channel) { startAntientropy(); } @Override protected void handleCursorRequest(CursorRequestMessage request, Channel channel) { try { Cursor c = null; if (request.isSetCursorId()) { c = syncManager.getCursor(request.getCursorId()); } else { c = syncManager.newCursor(request.getStoreName()); } if (c == null) { throw new SyncException("Unrecognized cursor"); } CursorResponseMessage m = new CursorResponseMessage(); AsyncMessageHeader header = new AsyncMessageHeader(); header.setTransactionId(request.getHeader().getTransactionId()); m.setHeader(header); m.setCursorId(c.getCursorId()); if (request.isClose()) { syncManager.closeCursor(c); } else { int i = 0; while (i < 50 && c.hasNext()) { Entry<ByteArray, List<Versioned<byte[]>>> e = c.next(); m.addToValues(TProtocolUtil.getTKeyedValues(e.getKey(), e.getValue())); i += 1; } } SyncMessage bsm = new SyncMessage(MessageType.CURSOR_RESPONSE); bsm.setCursorResponse(m); channel.write(bsm); } catch (Exception e) { channel.write(getError(request.getHeader().getTransactionId(), e, MessageType.CURSOR_REQUEST)); } } @Override protected void handleRegisterRequest(RegisterRequestMessage request, Channel channel) { try { Scope scope = TProtocolUtil.getScope(request.store.getScope()); if (request.store.isPersist()) syncManager.registerPersistentStore(request.store.storeName, scope); else syncManager.registerStore(request.store.storeName, scope); RegisterResponseMessage m = new RegisterResponseMessage(); AsyncMessageHeader header = new AsyncMessageHeader(); header.setTransactionId(request.getHeader().getTransactionId()); m.setHeader(header); SyncMessage bsm = new SyncMessage(MessageType.REGISTER_RESPONSE); bsm.setRegisterResponse(m); channel.write(bsm); } catch (Exception e) { channel.write(getError(request.getHeader().getTransactionId(), e, MessageType.REGISTER_REQUEST)); } } @Override protected void handleClusterJoinRequest(ClusterJoinRequestMessage request, Channel channel) { try { // We can get this message in two circumstances. Either this is // a totally new node, or this is an existing node that is changing // its port or IP address. We can tell the difference because the // node ID and domain ID will already be set for an existing node ClusterJoinResponseMessage cjrm = new ClusterJoinResponseMessage(); AsyncMessageHeader header = new AsyncMessageHeader(); header.setTransactionId(request.getHeader().getTransactionId()); cjrm.setHeader(header); org.sdnplatform.sync.thrift.Node tnode = request.getNode(); if (!tnode.isSetNodeId()) { // allocate a random node ID that's not currently in use // Note that there is an obvious possible race here if multiple // nodes join quickly or using different seeds. In this case, // if you get unlucky you could have the same random node ID // and then bad things would start to happen. We're essentially // assuming that node joins are happening one at a time by a // human; the randomness is a lame attempt to mitigate this race Random random = new Random(); short newNodeId; ClusterConfig cc = syncManager.getClusterConfig(); while (true) { newNodeId = (short)random.nextInt(Short.MAX_VALUE); if (cc.getNode(newNodeId) == null) break; } tnode.setNodeId(newNodeId); cjrm.setNewNodeId(newNodeId); } if (!tnode.isSetDomainId()) { // for now put the node into its own domain. Once it joins // the cluster, it can easily change its domain by writing a // new domain ID into the system node store tnode.setDomainId(tnode.getNodeId()); } IStoreClient<Short, Node> nodeStoreClient = syncManager.getStoreClient(SyncStoreCCProvider. SYSTEM_NODE_STORE, Short.class, Node.class); while (true) { try { Versioned<Node> node = nodeStoreClient.get(tnode.getNodeId()); node.setValue(new Node(tnode.getHostname(), tnode.getPort(), tnode.getNodeId(), tnode.getDomainId())); nodeStoreClient.put(tnode.getNodeId(), node); break; } catch (ObsoleteVersionException e) { } } IStorageEngine<ByteArray, byte[]> store = syncManager.getRawStore(SyncStoreCCProvider. SYSTEM_NODE_STORE); IClosableIterator<Entry<ByteArray, List<Versioned<byte[]>>>> entries = store.entries(); try { while (entries.hasNext()) { Entry<ByteArray, List<Versioned<byte[]>>> entry = entries.next(); KeyedValues kv = TProtocolUtil.getTKeyedValues(entry.getKey(), entry.getValue()); cjrm.addToNodeStore(kv); } } finally { entries.close(); } SyncMessage bsm = new SyncMessage(MessageType.CLUSTER_JOIN_RESPONSE); bsm.setClusterJoinResponse(cjrm); channel.write(bsm); } catch (Exception e) { channel.write(getError(request.getHeader().getTransactionId(), e, MessageType.CLUSTER_JOIN_REQUEST)); } } @Override protected void handleError(ErrorMessage error, Channel channel) { rpcService.messageAcked(error.getType(), getRemoteNodeId()); updateCounter(SyncManager.counterErrorRemote, 1); super.handleError(error, channel); } // ************************* // AbstractRPCChannelHandler // ************************* @Override protected Short getLocalNodeId() { return syncManager.getLocalNodeId(); } @Override protected Short getRemoteNodeId() { if (remoteNode != null) return remoteNode.getNodeId(); return null; } @Override protected String getLocalNodeIdString() { return ""+getLocalNodeId(); } @Override protected String getRemoteNodeIdString() { return ""+getRemoteNodeId(); } @Override protected int getTransactionId() { return rpcService.getTransactionId(); } @Override protected AuthScheme getAuthScheme() { return syncManager.getClusterConfig().getAuthScheme(); } @Override protected byte[] getSharedSecret() throws AuthException { String path = syncManager.getClusterConfig().getKeyStorePath(); String pass = syncManager.getClusterConfig().getKeyStorePassword(); try { return CryptoUtil.getSharedSecret(path, pass); } catch (Exception e) { throw new AuthException("Could not read challenge/response " + "shared secret from key store " + path, e); } } @Override protected SyncMessage getError(int transactionId, Exception error, MessageType type) { updateCounter(SyncManager.counterErrorProcessing, 1); return super.getError(transactionId, error, type); } // ***************** // Utility functions // ***************** protected void updateCounter(IDebugCounter counter, int incr) { counter.updateCounterWithFlush(incr); } protected void startAntientropy() { // Run antientropy in a background task so we don't use up an I/O // thread. Note that this task will result in lots of traffic // that will use I/O threads but each of those will be in manageable // chunks Runnable arTask = new Runnable() { @Override public void run() { syncManager.antientropy(remoteNode); } }; syncManager.getThreadPool().getScheduledExecutor().execute(arTask); } protected static class TVersionIterable implements Iterable<VectorClock> { final Iterable<org.sdnplatform.sync.thrift.VectorClock> tcvi; public TVersionIterable(Iterable<org.sdnplatform.sync.thrift.VectorClock> tcvi) { this.tcvi = tcvi; } @Override public Iterator<VectorClock> iterator() { final Iterator<org.sdnplatform.sync.thrift.VectorClock> tcs = tcvi.iterator(); return new Iterator<VectorClock>() { @Override public boolean hasNext() { return tcs.hasNext(); } @Override public VectorClock next() { return TProtocolUtil.getVersion(tcs.next()); } @Override public void remove() { tcs.remove(); } }; } } }