/* * Copyright (C) 2012 Facebook, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.facebook.zookeeper.mock; import com.facebook.collections.RetrieveableSet; import org.apache.zookeeper.CreateMode; import org.apache.zookeeper.KeeperException; import org.apache.zookeeper.WatchedEvent; import org.apache.zookeeper.Watcher; import org.apache.zookeeper.Watcher.Event.EventType; import org.apache.zookeeper.Watcher.Event.KeeperState; import org.apache.zookeeper.data.ACL; import org.apache.zookeeper.data.Stat; import org.apache.zookeeper.server.DataTree; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.EnumMap; import java.util.EnumSet; import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; // TODO: what actions trigger version increments? public class MockZooKeeperDataStore { private final AtomicLong nextSessionId = new AtomicLong(0); private final ZNode root = ZNode.createRoot(); private final Map<String, RetrieveableSet<ContextedWatcher>> creationWatchers = new HashMap<String, RetrieveableSet<ContextedWatcher>>(); public long getUniqueSessionId() { return nextSessionId.addAndGet(1); } public synchronized void signalSessionEvent(long sessionId, WatchedEvent watchedEvent) { for (RetrieveableSet<ContextedWatcher> pathWatchers : creationWatchers.values()) { for (ContextedWatcher contextedWatcher : pathWatchers) { if (contextedWatcher.getSessionId() == sessionId) { contextedWatcher.process(watchedEvent); } } } for (ZNode zNode : root) { zNode.signalSessionEvent(sessionId, watchedEvent); } } public synchronized void clearSession(long sessionId) { for (RetrieveableSet<ContextedWatcher> pathWatchers : creationWatchers.values()) { Iterator<ContextedWatcher> iter = pathWatchers.iterator(); while (iter.hasNext()) { ContextedWatcher contextedWatcher = iter.next(); if (contextedWatcher.getSessionId() == sessionId) { iter.remove(); } } } for (ZNode zNode : root) { zNode.clearSession(sessionId); } } public synchronized String create( long sessionId, String path, byte[] data, List<ACL> acl, CreateMode createMode ) throws KeeperException { if (isRootPath(path)) { throw new KeeperException.NodeExistsException(path); } String relativePath = stripRootFromPath(path); String relativeChildPath = root.createDescendant( sessionId, relativePath, data, acl, createMode ); String absChildPath = addRootToPath(relativeChildPath); // Trigger any creation watches that may exist if (creationWatchers.containsKey(absChildPath)) { WatchedEvent watchedEvent = new WatchedEvent( EventType.NodeCreated, KeeperState.SyncConnected, absChildPath ); for (Watcher watcher : creationWatchers.get(absChildPath)) { watcher.process(watchedEvent); } creationWatchers.remove(absChildPath); } return absChildPath; } public synchronized void delete(String path, int expectedVersion) throws KeeperException { if (isRootPath(path)) { throw new KeeperException.BadArgumentsException(path); } String relativePath = stripRootFromPath(path); root.deleteDescendant(relativePath, expectedVersion); } public synchronized Stat exists(long sessionId, String path, Watcher watcher) throws KeeperException { try { ZNode node = isRootPath(path) ? root : root.findDescendant(stripRootFromPath(path)); if (watcher != null) { node.addWatcher(sessionId, watcher, WatchTriggerPolicy.WatchType.EXISTS); } Stat stat = new Stat(); DataTree.copyStat(node.getStat(), stat); return stat; } catch (KeeperException.NoNodeException e) { if (watcher != null) { // Set a watch for this node when it gets created if (!creationWatchers.containsKey(path)) { creationWatchers.put(path, new RetrieveableSet<ContextedWatcher>()); } ContextedWatcher contextedWatcher = new ContextedWatcher( watcher, sessionId, WatchTriggerPolicy.WatchType.EXISTS ); if (!creationWatchers.get(path).contains(contextedWatcher)) { creationWatchers.get(path).add(contextedWatcher); } } return null; } } public synchronized byte[] getData(long sessionId, String path, Watcher watcher, Stat stat) throws KeeperException { ZNode node = isRootPath(path) ? root : root.findDescendant(stripRootFromPath(path)); if (watcher != null) { node.addWatcher(sessionId, watcher, WatchTriggerPolicy.WatchType.GETDATA); } if (stat != null) { DataTree.copyStat(node.getStat(), stat); } return node.getData(); } public synchronized Stat setData(String path, byte[] data, int expectedVersion) throws KeeperException { ZNode node = isRootPath(path) ? root : root.findDescendant(stripRootFromPath(path)); node.setData(data, expectedVersion); Stat stat = new Stat(); DataTree.copyStat(node.getStat(), stat); return stat; } public synchronized List<String> getChildren(long sessionId, String path, Watcher watcher) throws KeeperException { ZNode node = isRootPath(path) ? root : root.findDescendant(stripRootFromPath(path)); if (watcher != null) { node.addWatcher(sessionId, watcher, WatchTriggerPolicy.WatchType.GETCHILDREN); } return new ArrayList<String>(node.getChildren().keySet()); } private static boolean isRootPath(String path) { return path.equals("/"); } private static String stripRootFromPath(String path) { if (!path.startsWith("/")) { throw new IllegalArgumentException("Does not have root: " + path); } // Remove the leading slash for the root node return path.substring(1); } private static String addRootToPath(String path) { if (path.startsWith("/")) { throw new IllegalArgumentException("Already has root: " + path); } // Add the leading slash for the root node return "/" + path; } /** * ZNode: basic node storage unit. Collectively, they form the mock ZooKeeper * data storage tree hierarchy. * * For each ZNode: * - Contains basic tree traversal algorithms stemming from the current ZNode * - Maintains and signals watches set on the node * - Capable of iterating across its entire sub-tree * * Assumptions: * - All paths will be specified relative to the current node. For example, * given the following tree: * A * / \ * B C * / \ * D E * / * F * * If the current node is A, the path we specify to reach F will be: "B/D/F" * If the current node is B, the path we specify to reach F will be: "D/F" * Note: paths should never start or end with a '/' */ private static class ZNode implements Iterable<ZNode> { private final ZNode parent; private final String name; private byte[] data; private List<ACL> acl; private final CreateMode createMode; private final Stat stat = new Stat(); private final AtomicLong nextSeqNum = new AtomicLong(0); private final AtomicInteger version = new AtomicInteger(0); private final Map<String, ZNode> children = new HashMap<String, ZNode>(); private final RetrieveableSet<ContextedWatcher> contextedWatchers = new RetrieveableSet<ContextedWatcher>(); private ZNode( long sessionId, ZNode parent, String name, byte[] data, List<ACL> acl, CreateMode createMode ) { this.parent = parent; this.name = name; this.data = data; this.acl = acl; this.createMode = createMode; stat.setEphemeralOwner(createMode.isEphemeral() ? sessionId : 0); stat.setDataLength((data == null) ? 0 : data.length); stat.setNumChildren(0); stat.setVersion(version.get()); } public static ZNode createRoot() { return new ZNode(0, null, "", new byte[0], null, CreateMode.PERSISTENT); } public void addWatcher( long sessionId, Watcher watcher, WatchTriggerPolicy.WatchType watchType ) { ContextedWatcher contextedWatcher = new ContextedWatcher(watcher, sessionId, watchType); if (contextedWatchers.contains(contextedWatcher)) { contextedWatchers.get(contextedWatcher).merge(contextedWatcher); } else { contextedWatchers.add(contextedWatcher); } } public void clearSession(long sessionId) { // First remove all of your own watches Iterator<ContextedWatcher> iter = contextedWatchers.iterator(); while(iter.hasNext()) { if (iter.next().getSessionId() == sessionId) { iter.remove(); } } // Delete self if node is ephemeral if (stat.getEphemeralOwner() == sessionId) { try { delete(-1); } catch (KeeperException e) { throw new RuntimeException(e); } } // This session should not receive any callbacks as a result of clearing } public void signalSessionEvent(long sessionId, WatchedEvent watchedEvent) { for (ContextedWatcher contextedWatcher : contextedWatchers) { if (contextedWatcher.getSessionId() == sessionId) { contextedWatcher.process(watchedEvent); } } } public void signalNodeEvent(EventType eventType) { assert(eventType != EventType.None); WatchedEvent watchedEvent = new WatchedEvent( eventType, KeeperState.SyncConnected, addRootToPath(getPath()) ); Iterator<ContextedWatcher> iter = contextedWatchers.iterator(); while(iter.hasNext()) { ContextedWatcher contextedWatcher = iter.next(); if (contextedWatcher.shouldTrigger(eventType)) { iter.remove(); // Remove for one use contextedWatcher.process(watchedEvent); } } } public ZNode findDescendant(String path) throws KeeperException { List<String> pathParts = Arrays.asList(path.split("/")); ZNode lastSeenZNode = this; for (String childName : pathParts) { lastSeenZNode = lastSeenZNode.getChildren().get(childName); if (lastSeenZNode == null) { throw new KeeperException.NoNodeException(); } } return lastSeenZNode; } public ZNode findLeafParent(String path) throws KeeperException { if (!path.contains("/")) { // No slashes => this must be the parent return this; } return findDescendant(getLeafParentPath(path)); } private static String getLeafParentPath(String path) { int idx = path.lastIndexOf("/"); if (idx == -1) { throw new IllegalArgumentException("Path does not have parent: " + path); } return path.substring(0, idx); } public String getPath() { ZNode currentNode = this; String path = ""; while (!currentNode.isRoot()) { if (!path.isEmpty()) { path = "/" + path; } path = currentNode.getName() + path; currentNode = currentNode.getParent(); } return path; } private static String getLeafName(String path) { int idx = path.lastIndexOf("/"); if (idx == -1) { return path; } return path.substring(idx+1); } public String createDescendant( long sessionId, String path, byte[] data, List<ACL> acl, CreateMode createMode ) throws KeeperException { ZNode parent = findLeafParent(path); String childName = parent.createChild(sessionId, getLeafName(path), data, acl, createMode); return parent.isRoot() ? childName : parent.getPath() + "/" + childName; } public String createChild( long sessionId, String childName, byte[] data, List<ACL> acl, CreateMode createMode ) throws KeeperException { // Append a sequence number to path if sequential if (createMode.isSequential()) { childName += String.format("%08d", nextSeqNum.addAndGet(1)); } ZNode zNode = new ZNode(sessionId, this, childName, data, acl, createMode); addChild(zNode); zNode.signalNodeEvent(EventType.NodeCreated); return childName; } public void addChild(ZNode zNode) throws KeeperException { if (createMode.isEphemeral()) { throw new KeeperException.NoChildrenForEphemeralsException(); } if (children.containsKey(zNode.getName())) { throw new KeeperException.NodeExistsException(); } children.put(zNode.getName(), zNode); stat.setNumChildren(children.size()); signalNodeEvent(EventType.NodeChildrenChanged); } public void deleteDescendant(String path, int expectedVersion) throws KeeperException { findDescendant(path).delete(expectedVersion); } public void delete(int expectedVersion) throws KeeperException { assert(!isRoot()); if (!getChildren().isEmpty()) { throw new KeeperException.NotEmptyException(); } if (expectedVersion != -1 && getStat().getVersion() != expectedVersion) { throw new KeeperException.BadVersionException(); } if (getParent().children.remove(getName()) == null) { throw new KeeperException.NoNodeException(); } signalNodeEvent(EventType.NodeDeleted); getParent().signalNodeEvent(EventType.NodeChildrenChanged); } public boolean isRoot() { return parent == null; } public ZNode getParent() { return parent; } public String getName() { return name; } public byte[] getData() { return data; } public void setData(byte[] newData, int expectedVersion) throws KeeperException { if (expectedVersion != -1 && getStat().getVersion() != expectedVersion) { throw new KeeperException.BadVersionException(); } this.data = newData; stat.setDataLength((newData == null) ? 0 : newData.length); stat.setVersion(version.addAndGet(1)); signalNodeEvent(EventType.NodeDataChanged); } public List<ACL> getAcl() { return Collections.unmodifiableList(acl); } public Stat getStat() { return stat; } public Map<String, ZNode> getChildren() { return Collections.unmodifiableMap(children); } @Override public Iterator<ZNode> iterator() { return new ZNodeTreeIterator(this); } /** * Iterates across all ZNodes in the sub-tree rooted at the specified node * (will also return the specified ZNode). */ private static class ZNodeTreeIterator implements Iterator<ZNode> { private boolean selfReturned = false; private ZNode initialZNode; private Iterator<ZNode> childIter; private Iterator<ZNode> childTreeIter; private ZNode currentZNode; private ZNodeTreeIterator(ZNode initialZNode) { this.initialZNode = initialZNode; List<ZNode> childrenCopy = new ArrayList<ZNode>(initialZNode.getChildren().values()); childIter = childrenCopy.iterator(); } @Override public boolean hasNext() { if (!selfReturned) { return true; } if (childIter.hasNext()) { return true; } if (childTreeIter != null && childTreeIter.hasNext()) { return true; } return false; } @Override public ZNode next() { if (!selfReturned) { selfReturned = true; currentZNode = initialZNode; return initialZNode; } if (childTreeIter == null || !childTreeIter.hasNext()) { childTreeIter = childIter.next().iterator(); } currentZNode = childTreeIter.next(); return currentZNode; } @Override public void remove() { try { currentZNode.delete(-1); } catch (KeeperException e) { throw new RuntimeException(e); } } } } /** * Encapsulates a Watcher and the context in which it was created */ private static class ContextedWatcher implements Watcher { private final Watcher watcher; private final WatchContext watchContext; private ContextedWatcher( Watcher watcher, long sessionId, WatchTriggerPolicy.WatchType watchType ) { this.watcher = watcher; this.watchContext = new WatchContext(sessionId, watchType); } public long getSessionId() { return watchContext.getSessionId(); } public boolean shouldTrigger(EventType eventType) { return watchContext.shouldTrigger(eventType); } public void merge(ContextedWatcher contextedWatcher) { assert(watcher.equals(contextedWatcher.watcher)); watchContext.merge(contextedWatcher.watchContext); } @Override public void process(WatchedEvent event) { watcher.process(event); } @Override public boolean equals(Object o) { // Equality is only determined by the watcher if (this == o) { return true; } if (!(o instanceof ContextedWatcher)) { return false; } final ContextedWatcher that = (ContextedWatcher) o; if (!watcher.equals(that.watcher)) { return false; } return true; } @Override public int hashCode() { // Hash code only computed from the watcher return watcher.hashCode(); } private static class WatchContext { private final Set<WatchTriggerPolicy.WatchType> watchTypeSet = EnumSet.noneOf(WatchTriggerPolicy.WatchType.class); private long sessionId; private WatchContext(long sessionId, WatchTriggerPolicy.WatchType watchType) { this.sessionId = sessionId; watchTypeSet.add(watchType); } public long getSessionId() { return sessionId; } public boolean shouldTrigger(EventType eventType) { for (WatchTriggerPolicy.WatchType watchType : watchTypeSet) { if (WatchTriggerPolicy.shouldTrigger(watchType, eventType)) { return true; } } return false; } public void merge(WatchContext watchContext) { assert(sessionId == watchContext.getSessionId()); watchTypeSet.addAll(watchContext.watchTypeSet); } } } /** * Defines the ZooKeeper policies for when a particular watch type should be * triggered. */ private static class WatchTriggerPolicy { private enum WatchType { EXISTS, GETDATA, GETCHILDREN; } private static Map<WatchType, Set<EventType>> mapping = constructMapping(); private static Map<WatchType, Set<EventType>> constructMapping() { Map<WatchType, Set<EventType>> mapping = new EnumMap<WatchType, Set<EventType>>(WatchType.class); mapping.put(WatchType.EXISTS, EnumSet.of( EventType.NodeCreated, EventType.NodeDeleted, EventType.NodeDataChanged ) ); mapping.put(WatchType.GETDATA, EnumSet.of( EventType.NodeDeleted, EventType.NodeDataChanged ) ); mapping.put(WatchType.GETCHILDREN, EnumSet.of( EventType.NodeChildrenChanged, EventType.NodeDeleted ) ); return mapping; } public static boolean shouldTrigger(WatchType watchType, EventType eventType) { return mapping.get(watchType).contains(eventType); } } }