/**
* Licensed to the zk1931 under one or more contributor license
* agreements. See the NOTICE file distributed with this work
* for additional information regarding copyright ownership.
* The ASF licenses this file to you under the Apache License,
* Version 2.0 (the "License"); you may not use this file except
* in compliance with the License. You may obtain a copy of the
* License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.github.zk1931.jzab;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.nio.ByteBuffer;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Semaphore;
import java.util.Iterator;
import java.util.HashSet;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import org.junit.Assert;
import org.junit.Test;
import org.slf4j.LoggerFactory;
import org.slf4j.Logger;
class SnapshotStateMachine implements StateMachine {
private static final Logger LOG =
LoggerFactory.getLogger(SnapshotStateMachine.class);
ConcurrentHashMap<String, String> state =
new ConcurrentHashMap<String, String>();
CountDownLatch txnsCount;
Semaphore semMembership = new Semaphore(0);
Semaphore semSnapshot = new Semaphore(0);
SnapshotStateMachine(int nTxns) {
txnsCount = new CountDownLatch(nTxns);
}
@Override
public ByteBuffer preprocess(Zxid zxid, ByteBuffer message) {
return message;
}
@Override
public void deliver(Zxid zxid, ByteBuffer stateUpdate, String clientId,
Object ctx) {
byte[] buffer = new byte[stateUpdate.remaining()];
stateUpdate.get(buffer);
String key = new String(buffer);
LOG.debug("Delivering {}", key);
this.state.put(key, "value");
txnsCount.countDown();
}
@Override
public void flushed(Zxid zxid, ByteBuffer flushReq, Object ctx) {}
@Override
public void save(FileOutputStream fos) {
LOG.debug("SAVE is called.");
try {
ObjectOutputStream out = new ObjectOutputStream(fos);
out.writeObject(state);
} catch (IOException e) {
LOG.error("Caught exception", e);
}
}
@Override
public void snapshotDone(String filePath, Object ctx) {
LOG.debug("Snapshot is stored at {}", filePath);
semSnapshot.release();
}
@Override
public void restore(FileInputStream fis) {
LOG.debug("RESTORE is called.");
try {
ObjectInputStream oin = new ObjectInputStream(fis);
ConcurrentHashMap<?, ?> map = (ConcurrentHashMap<?, ?>)oin.readObject();
state = new ConcurrentHashMap<String, String>();
Iterator it = map.entrySet().iterator();
while (it.hasNext()) {
Map.Entry pairs = (Map.Entry)it.next();
state.put((String)pairs.getKey(), (String)pairs.getValue());
}
LOG.debug("The size of map after recovery from snapshot file is {}",
state.size());
} catch (Exception e) {
LOG.error("Caught exception", e);
}
}
@Override
public void removed(String serverId, Object ctx) {
}
@Override
public void recovering(PendingRequests pendings) {}
@Override
public void leading(Set<String> followers, Set<String> members) {
semMembership.release();
}
@Override
public void following(String ld, Set<String> members) {
semMembership.release();
}
void waitMemberChanged() throws InterruptedException {
semMembership.acquire();
}
void waitSnapshot() throws InterruptedException {
semSnapshot.acquire();
}
}
/**
* Tests for Snapshot.
*/
public class SnapshotTest extends TestBase {
private static final Logger LOG =
LoggerFactory.getLogger(SnapshotTest.class);
@Test(timeout=10000)
public void testSnapshotSingleServer() throws Exception {
final int nTxns = 20;
QuorumTestCallback cb1 = new QuorumTestCallback();
SnapshotStateMachine st1 = new SnapshotStateMachine(nTxns);
String server = getUniqueHostPort();
ZabConfig config = new ZabConfig();
// For testing purpose, set the threshold to 32 bytes..
config.setLogDir(getDirectory().getPath());
Zab zab = new Zab(st1, config, server, server);
st1.waitMemberChanged();
for (int i = 0; i < nTxns; ++i) {
zab.send(ByteBuffer.wrap(("txns" + i).getBytes()), null);
}
st1.txnsCount.await();
// Take the snapshot after all transaction gets delivered.
zab.takeSnapshot(null);
st1.waitSnapshot();
Thread.sleep(1000);
zab.shutdown();
SnapshotStateMachine stNew = new SnapshotStateMachine(nTxns);
zab = new Zab(stNew, config);
stNew.waitMemberChanged();
// Make sure restored state is consistent.
Assert.assertEquals(st1.state, stNew.state);
zab.shutdown();
}
@Test(timeout=20000)
public void testSnapshotCluster() throws Exception {
final int nTxns = 20;
QuorumTestCallback cb1 = new QuorumTestCallback();
SnapshotStateMachine st1 = new SnapshotStateMachine(nTxns);
QuorumTestCallback cb2 = new QuorumTestCallback();
SnapshotStateMachine st2 = new SnapshotStateMachine(nTxns);
String server1 = getUniqueHostPort();
String server2 = getUniqueHostPort();
ZabConfig config1 = new ZabConfig();
ZabConfig config2 = new ZabConfig();
config1.setLogDir(getDirectory().getPath() + File.separator + server1);
config2.setLogDir(getDirectory().getPath() + File.separator + server2);
Zab zab1 = new Zab(st1, config1, server1, server1);
st1.waitMemberChanged();
Zab zab2 = new Zab(st2, config2, server2, server1);
st2.waitMemberChanged();
int snapIdx = new Random().nextInt(nTxns);
for (int i = 0; i < nTxns; ++i) {
zab1.send(ByteBuffer.wrap(("txns" + i).getBytes()), null);
Thread.sleep(5);
if (i == snapIdx) {
zab1.takeSnapshot(null);
zab2.takeSnapshot(null);
st1.waitSnapshot();
st2.waitSnapshot();
}
}
st1.txnsCount.await();
st2.txnsCount.await();
// Shutdown both servers.
zab1.shutdown();
zab2.shutdown();
// Restarts them.
SnapshotStateMachine stNew1 = new SnapshotStateMachine(nTxns);
SnapshotStateMachine stNew2 = new SnapshotStateMachine(nTxns);
zab1 = new Zab(stNew1, config1);
zab2 = new Zab(stNew2, config2);
stNew1.waitMemberChanged();
stNew2.waitMemberChanged();
// Make sure the states are consistent.
Assert.assertEquals(st1.state, stNew1.state);
Assert.assertEquals(st2.state, stNew2.state);
Assert.assertEquals(stNew1.state, stNew2.state);
zab1.shutdown();
zab2.shutdown();
}
@Test(timeout=20000)
public void testSnapshotSynchronizationCase1() throws Exception {
// Starts server1, sends transactions txn1,txn2 ... txnn.
// Starts server2 joins server1, the snapshot will be used to synchronize
// server2. In the end, we verify the two state machines have the same state
final int nTxns = 50;
QuorumTestCallback cb1 = new QuorumTestCallback();
SnapshotStateMachine st1 = new SnapshotStateMachine(nTxns);
QuorumTestCallback cb2 = new QuorumTestCallback();
SnapshotStateMachine st2 = new SnapshotStateMachine(nTxns);
String server1 = getUniqueHostPort();
String server2 = getUniqueHostPort();
ZabConfig config1 = new ZabConfig();
config1.setLogDir(getDirectory().getPath() + File.separator + server1);
ZabConfig config2 = new ZabConfig();
config2.setLogDir(getDirectory().getPath() + File.separator + server2);
Zab zab1 = new Zab(st1, config1, server1, server1);
st1.waitMemberChanged();
int snapIdx = new Random().nextInt(nTxns);
for (int i = 0; i < nTxns; ++i) {
zab1.send(ByteBuffer.wrap(("txns" + i).getBytes()), null);
// Sleep a while to avoid all the transactions batch together.
Thread.sleep(5);
if (i == snapIdx) {
zab1.takeSnapshot(null);
st1.waitSnapshot();
}
}
st1.txnsCount.await();
// Server2 joins in.
Zab zab2 = new Zab(st2, config2, server2, server1);
st2.waitMemberChanged();
Assert.assertEquals(st1.state, st2.state);
zab2.shutdown();
st2 = new SnapshotStateMachine(nTxns);
// zab2 recovers.
zab2 = new Zab(st2, config2);
st2.waitMemberChanged();
// After recovery, we verify they still have same states.
Assert.assertEquals(st1.state, st2.state);
zab2.shutdown();
zab1.shutdown();
}
@Test(timeout=20000)
public void testSnapshotSynchronizationCase2() throws Exception {
QuorumTestCallback cb1 = new QuorumTestCallback();
QuorumTestCallback cb2 = new QuorumTestCallback();
SnapshotStateMachine st1 = new SnapshotStateMachine(0);
SnapshotStateMachine st2 = new SnapshotStateMachine(0);
String server1 = getUniqueHostPort();
String server2 = getUniqueHostPort();
String server3 = getUniqueHostPort();
Set<String> peers = new HashSet<>();
peers.add(server1);
peers.add(server2);
peers.add(server3);
PersistentState state1 = makeInitialState(server1, 5);
state1.setAckEpoch(0);
PersistentState state2 = makeInitialState(server2, 1);
state2.setAckEpoch(0);
ZabConfig config1 = new ZabConfig();
ZabConfig config2 = new ZabConfig();
Zab zab1 = new Zab(st1, config1, server1, peers, state1, cb1, null);
Zab zab2 = new Zab(st2, config2, server2, peers, state2, cb2, null);
cb1.waitBroadcasting();
cb2.waitBroadcasting();
Assert.assertEquals(cb1.initialHistory.size(), 5);
Assert.assertEquals(cb1.initialHistory.get(0).getZxid(), new Zxid(0, 0));
Assert.assertTrue(cb2.initialHistory.size() == 5);
Assert.assertEquals(cb2.initialHistory.get(0).getZxid(), new Zxid(0, 0));
// server1 gonna take snapshot.
zab1.takeSnapshot(null);
st1.waitSnapshot();
// Make sure server1 does take snapshot.
Assert.assertEquals(new Zxid(0, 4), state1.getSnapshotZxid());
// Shutdown zab2.
zab2.shutdown();
// Mannuly truncate all the logs of server2.
state2.getLog().truncate(Zxid.ZXID_NOT_EXIST);
st2 = new SnapshotStateMachine(0);
cb2 = new QuorumTestCallback();
// Restarts server2.
zab2 = new Zab(st2, config2, state2, cb2, null);
cb2.waitBroadcasting();
// server2 should get snapshot file synchronized from server1.
Assert.assertEquals(new Zxid(0, 4), state2.getSnapshotZxid());
// Eventually they will have same state.
Assert.assertEquals(st2.state, st1.state);
zab1.shutdown();
zab2.shutdown();
}
@Test(timeout=20000)
public void testSnapshotSynchronizationCase3() throws Exception {
QuorumTestCallback cb1 = new QuorumTestCallback();
QuorumTestCallback cb2 = new QuorumTestCallback();
SnapshotStateMachine st1 = new SnapshotStateMachine(0);
SnapshotStateMachine st2 = new SnapshotStateMachine(0);
String server1 = getUniqueHostPort();
String server2 = getUniqueHostPort();
String server3 = getUniqueHostPort();
Set<String> peers = new HashSet<>();
peers.add(server1);
peers.add(server2);
peers.add(server3);
PersistentState state1 = makeInitialState(server1, 5);
state1.setProposedEpoch(2);
state1.setAckEpoch(2);
PersistentState state2 = makeInitialState(server2, 1);
state2.setAckEpoch(0);
ZabConfig config1 = new ZabConfig();
ZabConfig config2 = new ZabConfig();
Zab zab1 = new Zab(st1, config1, server1, peers, state1, cb1, null);
Zab zab2 = new Zab(st2, config2, server2, peers, state2, cb2, null);
cb1.waitBroadcasting();
cb2.waitBroadcasting();
Assert.assertEquals(cb1.initialHistory.size(), 5);
Assert.assertEquals(cb1.initialHistory.get(0).getZxid(), new Zxid(0, 0));
Assert.assertTrue(cb2.initialHistory.size() == 5);
Assert.assertEquals(cb2.initialHistory.get(0).getZxid(), new Zxid(0, 0));
// server1 takes snapshot.
zab1.takeSnapshot(null);
st1.waitSnapshot();
// Make sure server1 did take snapshot.
Assert.assertEquals(new Zxid(0, 4), state1.getSnapshotZxid());
// Shutdown zab2.
zab2.shutdown();
// Add one more transaction to make server1 and server2 have different
// epochs.
appendTxns(state2.getLog(), new Zxid(1, 0), 1);
// Reset epoch number to make sure server1 becomes leader.
state2.setProposedEpoch(0);
state2.setAckEpoch(0);
st2 = new SnapshotStateMachine(0);
cb2 = new QuorumTestCallback();
// Restarts server2.
zab2 = new Zab(st2, config2, state2, cb2, null);
cb2.waitBroadcasting();
// server2 should get snapshot file synchronized from server1.
Assert.assertEquals(new Zxid(0, 4), state2.getSnapshotZxid());
// Eventuall they will have same state.
Assert.assertEquals(st2.state, st1.state);
zab1.shutdown();
zab2.shutdown();
}
}