/**
* Copyright 2016 Yahoo Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.zookeeper;
import java.lang.reflect.Constructor;
import java.util.List;
import java.util.Set;
import java.util.TreeMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.bookkeeper.mledger.util.Pair;
import org.apache.zookeeper.AsyncCallback.Children2Callback;
import org.apache.zookeeper.AsyncCallback.ChildrenCallback;
import org.apache.zookeeper.AsyncCallback.DataCallback;
import org.apache.zookeeper.AsyncCallback.StatCallback;
import org.apache.zookeeper.AsyncCallback.StringCallback;
import org.apache.zookeeper.AsyncCallback.VoidCallback;
import org.apache.zookeeper.Watcher.Event.EventType;
import org.apache.zookeeper.Watcher.Event.KeeperState;
import org.apache.zookeeper.data.ACL;
import org.apache.zookeeper.data.Stat;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Multimaps;
import com.google.common.collect.SetMultimap;
import com.google.common.collect.Sets;
import io.netty.util.concurrent.DefaultThreadFactory;
import sun.reflect.ReflectionFactory;
@SuppressWarnings({ "deprecation", "restriction", "rawtypes" })
public class MockZooKeeper extends ZooKeeper {
private TreeMap<String, Pair<String, Integer>> tree;
private SetMultimap<String, Watcher> watchers;
private boolean stopped;
private boolean alwaysFail = false;
private ExecutorService executor;
private AtomicInteger stepsToFail;
private KeeperException.Code failReturnCode;
private Watcher sessionWatcher;
private long sessionId = 0L;
private int readOpDelayMs;
public static MockZooKeeper newInstance() {
return newInstance(null);
}
public static MockZooKeeper newInstance(ExecutorService executor) {
return newInstance(executor, -1);
}
public static MockZooKeeper newInstance(ExecutorService executor, int readOpDelayMs) {
try {
ReflectionFactory rf = ReflectionFactory.getReflectionFactory();
Constructor objDef = Object.class.getDeclaredConstructor(new Class[0]);
Constructor intConstr = rf.newConstructorForSerialization(MockZooKeeper.class, objDef);
MockZooKeeper zk = MockZooKeeper.class.cast(intConstr.newInstance());
zk.init(executor);
zk.readOpDelayMs = readOpDelayMs;
return zk;
} catch (RuntimeException e) {
throw e;
} catch (Exception e) {
throw new IllegalStateException("Cannot create object", e);
}
}
private void init(ExecutorService executor) {
tree = Maps.newTreeMap();
if (executor != null) {
this.executor = executor;
} else {
this.executor = Executors.newFixedThreadPool(1, new DefaultThreadFactory("mock-zookeeper"));
}
SetMultimap<String, Watcher> w = HashMultimap.create();
watchers = Multimaps.synchronizedSetMultimap(w);
stopped = false;
stepsToFail = new AtomicInteger(-1);
failReturnCode = KeeperException.Code.OK;
}
private MockZooKeeper(String quorum) throws Exception {
// This constructor is never called
super(quorum, 1, new Watcher() {
@Override
public void process(WatchedEvent event) {
}
});
assert false;
}
@Override
public States getState() {
return States.CONNECTED;
}
@Override
public synchronized void register(Watcher watcher) {
sessionWatcher = watcher;
}
@Override
public synchronized String create(String path, byte[] data, List<ACL> acl, CreateMode createMode)
throws KeeperException, InterruptedException {
checkProgrammedFail();
if (stopped)
throw new KeeperException.ConnectionLossException();
if (tree.containsKey(path)) {
throw new KeeperException.NodeExistsException(path);
}
final String parent = path.substring(0, path.lastIndexOf("/"));
if (!parent.isEmpty() && !tree.containsKey(parent)) {
throw new KeeperException.NoNodeException();
}
if (createMode == CreateMode.EPHEMERAL_SEQUENTIAL || createMode == CreateMode.PERSISTENT_SEQUENTIAL) {
String parentData = tree.get(parent).first;
int parentVersion = tree.get(parent).second;
path = path + parentVersion;
// Update parent version
tree.put(parent, Pair.create(parentData, parentVersion + 1));
}
tree.put(path, Pair.create(new String(data), 0));
if (!parent.isEmpty()) {
final Set<Watcher> toNotifyParent = Sets.newHashSet();
toNotifyParent.addAll(watchers.get(parent));
executor.execute(() -> {
toNotifyParent.forEach(watcher -> watcher
.process(new WatchedEvent(EventType.NodeChildrenChanged, KeeperState.SyncConnected, parent)));
});
}
return path;
}
@Override
public synchronized void create(final String path, final byte[] data, final List<ACL> acl, CreateMode createMode,
final StringCallback cb, final Object ctx) {
if (stopped) {
cb.processResult(KeeperException.Code.CONNECTIONLOSS.intValue(), path, ctx, null);
return;
}
executor.execute(() -> {
String parent = path.substring(0, path.lastIndexOf("/"));
synchronized (MockZooKeeper.this) {
if (getProgrammedFailStatus()) {
cb.processResult(failReturnCode.intValue(), path, ctx, null);
} else if (stopped) {
cb.processResult(KeeperException.Code.CONNECTIONLOSS.intValue(), path, ctx, null);
} else if (tree.containsKey(path)) {
cb.processResult(KeeperException.Code.NODEEXISTS.intValue(), path, ctx, null);
} else if (!parent.isEmpty() && !tree.containsKey(parent)) {
cb.processResult(KeeperException.Code.NONODE.intValue(), path, ctx, null);
} else {
tree.put(path, Pair.create(new String(data), 0));
cb.processResult(0, path, ctx, null);
if (!parent.isEmpty()) {
watchers.get(parent).forEach(watcher -> watcher.process(
new WatchedEvent(EventType.NodeChildrenChanged, KeeperState.SyncConnected, parent)));
}
}
}
});
}
@Override
public synchronized byte[] getData(String path, Watcher watcher, Stat stat) throws KeeperException {
checkProgrammedFail();
Pair<String, Integer> value = tree.get(path);
if (value == null) {
throw new KeeperException.NoNodeException(path);
} else {
if (watcher != null) {
watchers.put(path, watcher);
}
if (stat != null) {
stat.setVersion(value.second);
}
return value.first.getBytes();
}
}
@Override
public void getData(final String path, boolean watch, final DataCallback cb, final Object ctx) {
executor.execute(() -> {
checkReadOpDelay();
if (getProgrammedFailStatus()) {
cb.processResult(failReturnCode.intValue(), path, ctx, null, null);
return;
} else if (stopped) {
cb.processResult(KeeperException.Code.ConnectionLoss, path, ctx, null, null);
return;
}
Pair<String, Integer> value;
synchronized (MockZooKeeper.this) {
value = tree.get(path);
}
if (value == null) {
cb.processResult(KeeperException.Code.NoNode, path, ctx, null, null);
} else {
Stat stat = new Stat();
stat.setVersion(value.second);
cb.processResult(0, path, ctx, value.first.getBytes(), stat);
}
});
}
@Override
public void getData(final String path, final Watcher watcher, final DataCallback cb, final Object ctx) {
executor.execute(() -> {
checkReadOpDelay();
synchronized (MockZooKeeper.this) {
if (getProgrammedFailStatus()) {
cb.processResult(failReturnCode.intValue(), path, ctx, null, null);
return;
} else if (stopped) {
cb.processResult(KeeperException.Code.CONNECTIONLOSS.intValue(), path, ctx, null, null);
return;
}
Pair<String, Integer> value = tree.get(path);
if (value == null) {
cb.processResult(KeeperException.Code.NONODE.intValue(), path, ctx, null, null);
} else {
if (watcher != null) {
watchers.put(path, watcher);
}
Stat stat = new Stat();
stat.setVersion(value.second);
cb.processResult(0, path, ctx, value.first.getBytes(), stat);
}
}
});
}
@Override
public void getChildren(final String path, final Watcher watcher, final ChildrenCallback cb, final Object ctx) {
executor.execute(() -> {
synchronized (MockZooKeeper.this) {
if (getProgrammedFailStatus()) {
cb.processResult(failReturnCode.intValue(), path, ctx, null);
return;
} else if (stopped) {
cb.processResult(KeeperException.Code.ConnectionLoss, path, ctx, null);
return;
}
List<String> children = Lists.newArrayList();
for (String item : tree.tailMap(path).keySet()) {
if (!item.startsWith(path)) {
break;
} else {
if (path.length() >= item.length()) {
continue;
}
String child = item.substring(path.length() + 1);
if (!child.contains("/")) {
children.add(child);
}
}
}
cb.processResult(0, path, ctx, children);
if (watcher != null) {
watchers.put(path, watcher);
}
}
});
}
@Override
public synchronized List<String> getChildren(String path, Watcher watcher) throws KeeperException {
checkProgrammedFail();
if (!tree.containsKey(path)) {
throw new KeeperException.NoNodeException();
}
List<String> children = Lists.newArrayList();
for (String item : tree.tailMap(path).keySet()) {
if (!item.startsWith(path)) {
break;
} else {
if (path.length() >= item.length()) {
continue;
}
String child = item.substring(path.length() + 1);
if (!child.contains("/")) {
children.add(child);
}
}
}
if (watcher != null) {
watchers.put(path, watcher);
}
return children;
}
@Override
public synchronized List<String> getChildren(String path, boolean watch)
throws KeeperException, InterruptedException {
checkProgrammedFail();
if (stopped) {
throw new KeeperException.ConnectionLossException();
} else if (!tree.containsKey(path)) {
throw new KeeperException.NoNodeException();
}
List<String> children = Lists.newArrayList();
for (String item : tree.tailMap(path).keySet()) {
if (!item.startsWith(path)) {
break;
} else {
if (path.length() >= item.length()) {
continue;
}
String child = item.substring(path.length() + 1);
if (!child.contains("/")) {
children.add(child);
}
}
}
return children;
}
@Override
public void getChildren(final String path, boolean watcher, final Children2Callback cb, final Object ctx) {
executor.execute(() -> {
synchronized (MockZooKeeper.this) {
if (getProgrammedFailStatus()) {
cb.processResult(failReturnCode.intValue(), path, ctx, null, null);
return;
} else if (stopped) {
cb.processResult(KeeperException.Code.ConnectionLoss, path, ctx, null, null);
return;
} else if (!tree.containsKey(path)) {
cb.processResult(KeeperException.Code.NoNode, path, ctx, null, null);
return;
}
log.debug("getChildren path={}", path);
List<String> children = Lists.newArrayList();
for (String item : tree.tailMap(path).keySet()) {
log.debug("Checking path {}", item);
if (!item.startsWith(path)) {
break;
} else if (item.equals(path)) {
continue;
} else {
String child = item.substring(path.length() + 1);
log.debug("child: '{}'", child);
if (!child.contains("/")) {
children.add(child);
}
}
}
log.debug("getChildren done path={} result={}", path, children);
cb.processResult(0, path, ctx, children, new Stat());
}
});
}
@Override
public synchronized Stat exists(String path, boolean watch) throws KeeperException, InterruptedException {
checkProgrammedFail();
if (stopped)
throw new KeeperException.ConnectionLossException();
if (tree.containsKey(path)) {
Stat stat = new Stat();
stat.setVersion(tree.get(path).second);
return stat;
} else {
return null;
}
}
@Override
public synchronized Stat exists(String path, Watcher watcher) throws KeeperException, InterruptedException {
checkProgrammedFail();
if (stopped)
throw new KeeperException.ConnectionLossException();
if (watcher != null) {
watchers.put(path, watcher);
}
if (tree.containsKey(path)) {
Stat stat = new Stat();
stat.setVersion(tree.get(path).second);
return stat;
} else {
return null;
}
}
public void exists(String path, boolean watch, StatCallback cb, Object ctx) {
executor.execute(() -> {
synchronized (this) {
if (getProgrammedFailStatus()) {
cb.processResult(failReturnCode.intValue(), path, ctx, null);
return;
} else if (stopped) {
cb.processResult(KeeperException.Code.ConnectionLoss, path, ctx, null);
return;
}
if (tree.containsKey(path)) {
cb.processResult(0, path, ctx, new Stat());
} else {
cb.processResult(KeeperException.Code.NoNode, path, ctx, null);
}
}
});
}
@Override
public void sync(String path, VoidCallback cb, Object ctx) {
executor.execute(() -> {
synchronized (this) {
if (getProgrammedFailStatus()) {
cb.processResult(failReturnCode.intValue(), path, ctx);
return;
} else if (stopped) {
cb.processResult(KeeperException.Code.ConnectionLoss, path, ctx);
return;
}
cb.processResult(0, path, ctx);
}
});
}
@Override
public Stat setData(final String path, byte[] data, int version) throws KeeperException, InterruptedException {
final Set<Watcher> toNotify = Sets.newHashSet();
int newVersion;
synchronized (this) {
checkProgrammedFail();
if (stopped) {
throw new KeeperException.ConnectionLossException();
}
if (!tree.containsKey(path)) {
throw new KeeperException.NoNodeException();
}
int currentVersion = tree.get(path).second;
// Check version
if (version != -1 && version != currentVersion) {
throw new KeeperException.BadVersionException(path);
}
newVersion = currentVersion + 1;
log.debug("[{}] Updating -- current version: {}", path, currentVersion);
tree.put(path, Pair.create(new String(data), newVersion));
toNotify.addAll(watchers.get(path));
watchers.removeAll(path);
}
executor.execute(() -> {
toNotify.forEach(watcher -> watcher
.process(new WatchedEvent(EventType.NodeDataChanged, KeeperState.SyncConnected, path)));
});
Stat stat = new Stat();
stat.setVersion(newVersion);
return stat;
}
@Override
public synchronized void setData(final String path, final byte[] data, int version, final StatCallback cb,
final Object ctx) {
if (stopped) {
cb.processResult(KeeperException.Code.ConnectionLoss, path, ctx, null);
return;
}
executor.execute(() -> {
final Set<Watcher> toNotify = Sets.newHashSet();
synchronized (MockZooKeeper.this) {
if (getProgrammedFailStatus()) {
cb.processResult(failReturnCode.intValue(), path, ctx, null);
return;
} else if (stopped) {
cb.processResult(KeeperException.Code.ConnectionLoss, path, ctx, null);
return;
}
if (!tree.containsKey(path)) {
cb.processResult(KeeperException.Code.NoNode, path, ctx, null);
return;
}
int currentVersion = tree.get(path).second;
// Check version
if (version != -1 && version != currentVersion) {
log.debug("[{}] Current version: {} -- Expected: {}", path, currentVersion, version);
cb.processResult(KeeperException.Code.BadVersion, path, ctx, null);
return;
}
int newVersion = currentVersion + 1;
log.debug("[{}] Updating -- current version: {}", path, currentVersion);
tree.put(path, Pair.create(new String(data), newVersion));
Stat stat = new Stat();
stat.setVersion(newVersion);
cb.processResult(0, path, ctx, stat);
toNotify.addAll(watchers.get(path));
watchers.removeAll(path);
}
for (Watcher watcher : toNotify) {
watcher.process(new WatchedEvent(EventType.NodeDataChanged, KeeperState.SyncConnected, path));
}
});
}
@Override
public void delete(final String path, int version) throws InterruptedException, KeeperException {
checkProgrammedFail();
final Set<Watcher> toNotifyDelete;
final Set<Watcher> toNotifyParent;
final String parent;
synchronized (this) {
if (stopped) {
throw new KeeperException.ConnectionLossException();
} else if (!tree.containsKey(path)) {
throw new KeeperException.NoNodeException(path);
} else if (hasChildren(path)) {
throw new KeeperException.NotEmptyException(path);
}
if (version != -1) {
int currentVersion = tree.get(path).second;
if (version != currentVersion) {
throw new KeeperException.BadVersionException(path);
}
}
tree.remove(path);
toNotifyDelete = Sets.newHashSet();
toNotifyDelete.addAll(watchers.get(path));
toNotifyParent = Sets.newHashSet();
parent = path.substring(0, path.lastIndexOf("/"));
if (!parent.isEmpty()) {
toNotifyParent.addAll(watchers.get(parent));
}
watchers.removeAll(path);
}
executor.execute(() -> {
if (stopped) {
return;
}
for (Watcher watcher1 : toNotifyDelete) {
watcher1.process(new WatchedEvent(EventType.NodeDeleted, KeeperState.SyncConnected, path));
}
for (Watcher watcher2 : toNotifyParent) {
watcher2.process(new WatchedEvent(EventType.NodeChildrenChanged, KeeperState.SyncConnected, parent));
}
});
}
@Override
public synchronized void delete(final String path, int version, final VoidCallback cb, final Object ctx) {
if (executor.isShutdown()) {
cb.processResult(KeeperException.Code.SESSIONEXPIRED.intValue(), path, ctx);
return;
}
final Set<Watcher> toNotifyDelete = Sets.newHashSet();
toNotifyDelete.addAll(watchers.get(path));
final Set<Watcher> toNotifyParent = Sets.newHashSet();
final String parent = path.substring(0, path.lastIndexOf("/"));
if (!parent.isEmpty()) {
toNotifyParent.addAll(watchers.get(parent));
}
executor.execute(() -> {
if (getProgrammedFailStatus()) {
cb.processResult(failReturnCode.intValue(), path, ctx);
} else if (stopped) {
cb.processResult(KeeperException.Code.CONNECTIONLOSS.intValue(), path, ctx);
} else if (!tree.containsKey(path)) {
cb.processResult(KeeperException.Code.NONODE.intValue(), path, ctx);
} else if (hasChildren(path)) {
cb.processResult(KeeperException.Code.NOTEMPTY.intValue(), path, ctx);
} else {
if (version != -1) {
int currentVersion = tree.get(path).second;
if (version != currentVersion) {
cb.processResult(KeeperException.Code.BADVERSION.intValue(), path, ctx);
return;
}
}
tree.remove(path);
cb.processResult(0, path, ctx);
toNotifyDelete.forEach(watcher -> watcher
.process(new WatchedEvent(EventType.NodeDeleted, KeeperState.SyncConnected, path)));
toNotifyParent.forEach(watcher -> watcher
.process(new WatchedEvent(EventType.NodeChildrenChanged, KeeperState.SyncConnected, parent)));
}
});
watchers.removeAll(path);
}
@Override
public void close() throws InterruptedException {
}
public synchronized void shutdown() throws InterruptedException {
stopped = true;
tree.clear();
watchers.clear();
executor.shutdownNow();
}
void checkProgrammedFail() throws KeeperException {
if (stepsToFail.getAndDecrement() == 0 || this.alwaysFail) {
throw KeeperException.create(failReturnCode);
}
}
boolean getProgrammedFailStatus() {
return stepsToFail.getAndDecrement() == 0;
}
public void failNow(KeeperException.Code rc) {
failAfter(0, rc);
}
public void setAlwaysFail(KeeperException.Code rc) {
this.alwaysFail = true;
this.failReturnCode = rc;
}
public void unsetAlwaysFail() {
this.alwaysFail = false;
}
public void failAfter(int steps, KeeperException.Code rc) {
stepsToFail.set(steps);
failReturnCode = rc;
}
public void setSessionId(long id) {
sessionId = id;
}
@Override
public long getSessionId() {
return sessionId;
}
private boolean hasChildren(String path) {
return !tree.subMap(path + '/', path + '0').isEmpty();
}
@Override
public String toString() {
return "MockZookeeper";
}
private void checkReadOpDelay() {
if (readOpDelayMs > 0) {
try {
Thread.sleep(readOpDelayMs);
} catch (InterruptedException e) {
// Ok
}
}
}
private static final Logger log = LoggerFactory.getLogger(MockZooKeeper.class);
}