package org.infinispan.partitionhandling;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.fail;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.infinispan.Cache;
import org.infinispan.configuration.cache.CacheMode;
import org.infinispan.configuration.cache.ConfigurationBuilder;
import org.infinispan.notifications.Listener;
import org.infinispan.notifications.cachemanagerlistener.annotation.ViewChanged;
import org.infinispan.notifications.cachemanagerlistener.event.ViewChangedEvent;
import org.infinispan.partitionhandling.impl.PartitionHandlingManager;
import org.infinispan.remoting.transport.AbstractDelegatingTransport;
import org.infinispan.remoting.transport.Transport;
import org.infinispan.remoting.transport.jgroups.JGroupsTransport;
import org.infinispan.test.MultipleCacheManagersTest;
import org.infinispan.test.TestingUtil;
import org.infinispan.test.fwk.TEST_PING;
import org.infinispan.test.fwk.TransportFlags;
import org.infinispan.util.logging.Log;
import org.infinispan.util.logging.LogFactory;
import org.jgroups.Address;
import org.jgroups.JChannel;
import org.jgroups.MergeView;
import org.jgroups.View;
import org.jgroups.protocols.DISCARD;
import org.jgroups.protocols.Discovery;
import org.jgroups.protocols.TP;
import org.jgroups.protocols.pbcast.GMS;
import org.jgroups.stack.Protocol;
import org.jgroups.stack.ProtocolStack;
import org.testng.annotations.Test;
/**
 * Base class for partition-handling tests.
 * <p>
 * It starts {@link #numMembersInCluster} clustered caches, using {@link #cacheMode} and the
 * {@link #partitionHandling} flag, on transports created with failure detection and merge flags enabled.
 * It also offers helpers to split the underlying JGroups cluster into {@link Partition}s, merge them back,
 * and assert key availability and availability mode per partition.
 */
@Test(groups = "functional", testName = "partitionhandling.BasePartitionHandlingTest")
public class BasePartitionHandlingTest extends MultipleCacheManagersTest {
   private static final Log log = LogFactory.getLog(BasePartitionHandlingTest.class);
   /**
    * Generates ids for the views installed manually by {@link Partition}.
    * NOTE(review): the initial value 5 presumably keeps manual view ids above those produced while the
    * cluster formed — TODO confirm.
    */
   private final AtomicInteger viewId = new AtomicInteger(5);
   protected int numMembersInCluster = 4;
   protected CacheMode cacheMode = CacheMode.DIST_SYNC;
   /** Current partitions; assigned by {@link #splitCluster} / {@link #isolatePartition}, shrunk by merges. */
   protected volatile Partition[] partitions;
   protected boolean partitionHandling = true;

   public BasePartitionHandlingTest() {
      this.cleanup = CleanupPhase.AFTER_METHOD;
   }

   @Override
   protected void createCacheManagers() throws Throwable {
      ConfigurationBuilder dcc = cacheConfiguration();
      dcc.clustering().cacheMode(cacheMode).partitionHandling().enabled(partitionHandling);
      createClusteredCaches(numMembersInCluster, dcc, new TransportFlags().withFD(true).withMerge(true));
      waitForClusterToForm();
   }

   /**
    * Hook for subclasses to supply the base cache configuration. The cache mode and the partition-handling
    * flag are applied on top of the returned builder by {@link #createCacheManagers()}.
    */
   protected ConfigurationBuilder cacheConfiguration() {
      return new ConfigurationBuilder();
   }

   /** Cache-manager listener that records whether any view change was observed. */
   @Listener
   static class ViewChangedHandler {
      volatile boolean notified = false;

      @ViewChanged
      public void viewChanged(ViewChangedEvent vce) {
         notified = true;
      }
   }

   /** Holds the cache indexes that make up one partition. */
   public static class PartitionDescriptor {
      int[] nodes;

      public PartitionDescriptor(int... nodes) {
         this.nodes = nodes;
      }

      public int[] getNodes() {
         return nodes;
      }

      public int node(int i) {
         return nodes[i];
      }
   }

   /**
    * A group of JGroups channels that can be cut off from the rest of the cluster.
    * <p>
    * {@link #partition()} inserts a DISCARD protocol above the transport that ignores every member outside
    * the group. It then suspends discovery and installs a view containing only the group's members.
    */
   public class Partition {
      private final List<Address> allMembers;
      List<JChannel> channels = new ArrayList<>();

      /**
       * @param allMembers the full cluster membership; members whose channels are not added to this partition
       *                   are discarded when {@link #partition()} is invoked
       */
      public Partition(List<Address> allMembers) {
         this.allMembers = allMembers;
      }

      /** Adds a channel to this partition; must be called before {@link #partition()}. */
      public void addNode(JChannel c) {
         channels.add(c);
      }

      /**
       * Isolates this partition: discards traffic from outside members, suspends discovery, installs a view
       * with only this partition's members, and verifies every channel sees exactly that view.
       */
      public void partition() {
         discardOtherMembers();
         log.trace("Partition forming");
         disableDiscovery();
         installNewView();
         assertPartitionFormed();
         log.trace("New views installed");
      }

      /** Suspends the discovery protocol on every channel; only TEST_PING is supported. */
      private void disableDiscovery() {
         for (JChannel c : channels) {
            for (Protocol p : c.getProtocolStack().getProtocols()) {
               if (p instanceof Discovery) {
                  if (!(p instanceof TEST_PING)) throw new IllegalStateException("TEST_PING required for this test.");
                  ((TEST_PING) p).suspend();
               }
            }
         }
      }

      /** Fails if any channel's current view members differ (including order) from this partition's members. */
      private void assertPartitionFormed() {
         List<Address> viewMembers = channelAddresses(channels);
         for (JChannel c : channels) {
            List<Address> members = c.getView().getMembers();
            if (!members.equals(viewMembers)) {
               throw new AssertionError("Partition " + this + " not formed: " + c.getAddress()
                     + " has view members " + members + " but expected " + viewMembers);
            }
         }
      }

      /**
       * Installs, via GMS on every channel, a view coordinated by the first channel and containing only this
       * partition's members.
       *
       * @return the members of the installed view
       */
      private List<Address> installNewView() {
         List<Address> viewMembers = channelAddresses(channels);
         View view = View.create(channels.get(0).getAddress(), viewId.incrementAndGet(),
               viewMembers.toArray(new Address[0]));
         log.trace("Before installing new view...");
         for (JChannel c : channels)
            ((GMS) c.getProtocolStack().findProtocol(GMS.class)).installView(view);
         return viewMembers;
      }

      /**
       * Installs on every channel of this partition a MergeView whose subgroups are built from the two given
       * channel lists.
       *
       * @return the members of the installed merge view (distinct addresses of both lists, in order)
       */
      private List<Address> installMergeView(List<JChannel> view1, List<JChannel> view2) {
         List<Address> allAddresses = Stream.concat(view1.stream(), view2.stream())
               .map(JChannel::getAddress)
               .distinct()
               .collect(Collectors.toList());
         List<View> allViews = new ArrayList<>();
         // Subgroup views are created (and consume view ids) before the merge view itself.
         allViews.add(toView(view1));
         allViews.add(toView(view2));
         MergeView mv = new MergeView(view1.get(0).getAddress(), (long) viewId.incrementAndGet(), allAddresses, allViews);
         for (JChannel c : channels)
            ((GMS) c.getProtocolStack().findProtocol(GMS.class)).installView(mv);
         return allAddresses;
      }

      /** Builds a view with a fresh id, coordinated by the first channel, containing the given channels. */
      private View toView(List<JChannel> viewChannels) {
         List<Address> viewMembers = channelAddresses(viewChannels);
         return View.create(viewChannels.get(0).getAddress(), viewId.incrementAndGet(),
               viewMembers.toArray(new Address[0]));
      }

      /**
       * Inserts a new DISCARD protocol above the transport of each channel, ignoring every member of
       * {@code allMembers} that is not part of this partition.
       */
      private void discardOtherMembers() {
         List<Address> localMembers = channelAddresses(channels);
         List<Address> outsideMembers = allMembers.stream()
               .filter(a -> !localMembers.contains(a))
               .collect(Collectors.toList());
         for (JChannel c : channels) {
            DISCARD discard = new DISCARD();
            for (Address a : outsideMembers) discard.addIgnoreMember(a);
            try {
               c.getProtocolStack().insertProtocol(discard, ProtocolStack.Position.ABOVE, TP.class);
            } catch (Exception e) {
               throw new RuntimeException(e);
            }
         }
      }

      @Override
      public String toString() {
         StringBuilder sb = new StringBuilder("Partition{");
         for (JChannel c : channels) sb.append(c.getAddress()).append(' ');
         return sb.append('}').toString();
      }

      /**
       * Merges {@code partition} into this one and waits for the merged caches to settle.
       * <p>
       * The DISCARD filters between the two partitions are lifted, the other partition's channels are added to
       * this one, and a MergeView is installed. The other partition is then removed from the enclosing test's
       * {@code partitions} array.
       *
       * @throws AssertionError if {@code partition} is not registered in the enclosing test's partitions
       */
      public void merge(Partition partition) {
         observeMembers(partition);
         partition.observeMembers(this);
         List<JChannel> view1 = new ArrayList<>(channels);
         List<JChannel> view2 = new ArrayList<>(partition.channels);
         for (JChannel c : partition.channels) {
            if (!channels.contains(c)) channels.add(c);
         }
         installMergeView(view1, view2);
         waitForPartitionToForm();
         List<Partition> tmp = new ArrayList<>(Arrays.asList(BasePartitionHandlingTest.this.partitions));
         if (!tmp.remove(partition))
            throw new AssertionError("Partition " + partition + " is not registered in " + tmp);
         BasePartitionHandlingTest.this.partitions = tmp.toArray(new Partition[0]);
      }

      /**
       * Blocks until the caches whose channels belong to this partition have received their views and, for
       * clustered cache modes, until no rebalance is in progress.
       */
      private void waitForPartitionToForm() {
         List<Cache<Object, Object>> caches = new ArrayList<>(getCaches(null));
         caches.removeIf(c -> !channels.contains(channel(c)));
         Cache<Object, Object> cache = caches.get(0);
         TestingUtil.blockUntilViewsReceived(10000, caches);
         if (cache.getCacheConfiguration().clustering().cacheMode().isClustered()) {
            TestingUtil.waitForNoRebalance(caches);
         }
      }

      /**
       * Calls {@code start()} on every Discovery protocol of this partition's channels.
       * NOTE(review): {@link #partition()} suspends TEST_PING via {@code suspend()}; confirm that
       * {@code start()} is the intended counterpart.
       */
      public void enableDiscovery() {
         for (JChannel c : channels) {
            for (Protocol p : c.getProtocolStack().getProtocols()) {
               if (p instanceof Discovery) {
                  try {
                     log.tracef("About to start discovery: %s", p);
                     p.start();
                  } catch (Exception e) {
                     throw new RuntimeException(e);
                  }
               }
            }
         }
         log.trace("Discovery started.");
      }

      /** Removes the other partition's addresses from every DISCARD protocol on this partition's channels. */
      private void observeMembers(Partition partition) {
         for (JChannel c : channels) {
            List<Protocol> protocols = c.getProtocolStack().getProtocols();
            for (Protocol p : protocols) {
               if (p instanceof DISCARD) {
                  for (JChannel oc : partition.channels) {
                     ((DISCARD) p).removeIgnoredMember(oc.getAddress());
                  }
               }
            }
         }
      }

      /** Asserts this partition eventually reaches DEGRADED_MODE; a no-op when partition handling is disabled. */
      public void assertDegradedMode() {
         if (partitionHandling) {
            assertAvailabilityMode(AvailabilityMode.DEGRADED_MODE);
         }
      }

      /** Asserts every cache in this partition reads {@code expectedValue} for {@code k}. */
      public void assertKeyAvailableForRead(Object k, Object expectedValue) {
         for (Cache<Object, Object> c : cachesInThisPartition()) {
            assertEquals(c.get(k), expectedValue, "Cache " + c.getAdvancedCache().getRpcManager().getAddress() + " doesn't see the right value: ");
         }
      }

      /** Writes {@code newValue} for {@code k} through every cache in this partition and asserts it reads back. */
      public void assertKeyAvailableForWrite(Object k, Object newValue) {
         for (Cache<Object, Object> c : cachesInThisPartition()) {
            c.put(k, newValue);
            assertEquals(c.get(k), newValue, "Cache " + c.getAdvancedCache().getRpcManager().getAddress() + " doesn't see the right value");
         }
      }

      protected void assertKeysNotAvailableForRead(Object... keys) {
         for (Object k : keys)
            assertKeyNotAvailableForRead(k);
      }

      /** Asserts that reading {@code key} throws AvailabilityException on every cache in this partition. */
      protected void assertKeyNotAvailableForRead(Object key) {
         for (Cache<Object, ?> c : cachesInThisPartition()) {
            try {
               c.get(key);
               fail("Key " + key + " available in cache " + address(c));
            } catch (AvailabilityException ae) {
               //expected!
            }
         }
      }

      /** Returns the test's caches whose JGroups channel belongs to this partition. */
      private <K, V> List<Cache<K, V>> cachesInThisPartition() {
         List<Cache<K, V>> caches = new ArrayList<>();
         for (final Cache<K, V> c : BasePartitionHandlingTest.this.<K, V>caches()) {
            if (channels.contains(channel(c))) {
               caches.add(c);
            }
         }
         return caches;
      }

      /** Asserts that writing {@code key} throws AvailabilityException on every cache in this partition. */
      public void assertKeyNotAvailableForWrite(Object key) {
         for (Cache<Object, Object> c : cachesInThisPartition()) {
            try {
               c.put(key, key);
               fail("Key " + key + " available for write in cache " + address(c));
            } catch (AvailabilityException ae) {
               //expected!
            }
         }
      }

      public void assertKeysNotAvailableForWrite(Object ... keys) {
         for (Object k : keys) {
            assertKeyNotAvailableForWrite(k);
         }
      }

      /** Asserts every cache in this partition eventually reports {@code state}. */
      public void assertAvailabilityMode(final AvailabilityMode state) {
         for (final Cache<Object, Object> c : cachesInThisPartition()) {
            eventuallyEquals(state, () -> partitionHandlingManager(c).getAvailabilityMode());
         }
      }
   }

   /**
    * Splits the cluster into one partition per array of cache indexes and partitions each of them in order.
    * The resulting partitions are stored in {@link #partitions}, in the same order as {@code parts}.
    */
   protected void splitCluster(int[]... parts) {
      List<Address> allMembers = channel(0).getView().getMembers();
      partitions = new Partition[parts.length];
      for (int i = 0; i < parts.length; i++) {
         Partition p = new Partition(allMembers);
         for (int j : parts[i]) {
            p.addNode(channel(j));
         }
         partitions[i] = p;
         p.partition();
      }
   }

   /**
    * Partitions only the given cache indexes away and stores {@code {p0, p1}} in {@link #partitions}.
    * {@code p1} holds the isolated nodes.
    * NOTE(review): {@code p0} is populated with every node, including the isolated ones, and is never
    * partitioned. Confirm that callers expect {@code partition(0)} to span the whole original cluster.
    */
   protected void isolatePartition(int[] isolatedPartition) {
      List<Address> allMembers = channel(0).getView().getMembers();
      Partition p0 = new Partition(allMembers);
      IntStream.range(0, allMembers.size()).forEach(i -> p0.addNode(channel(i)));
      Partition p1 = new Partition(allMembers);
      Arrays.stream(isolatedPartition).forEach(i -> p1.addNode(channel(i)));
      p1.partition();
      partitions = new Partition[]{p0, p1};
   }

   /** Collects the addresses of the given channels, preserving order. */
   private static List<Address> channelAddresses(List<JChannel> channels) {
      List<Address> addresses = new ArrayList<>(channels.size());
      for (JChannel c : channels) addresses.add(c.getAddress());
      return addresses;
   }

   private JChannel channel(int i) {
      return channel(cache(i));
   }

   private JChannel channel(Cache<?, ?> cache) {
      return extractJGroupsTransport(cache.getAdvancedCache().getRpcManager().getTransport()).getChannel();
   }

   /**
    * @throws IllegalStateException if neither {@link #splitCluster} nor {@link #isolatePartition} has run yet
    */
   protected Partition partition(int i) {
      if (partitions == null)
         throw new IllegalStateException("splitCluster(..) must be invoked before this method!");
      return partitions[i];
   }

   protected PartitionHandlingManager partitionHandlingManager(int index) {
      return partitionHandlingManager(advancedCache(index));
   }

   protected PartitionHandlingManager partitionHandlingManager(Cache<?, ?> cache) {
      return cache.getAdvancedCache().getComponentRegistry().getComponent(PartitionHandlingManager.class);
   }

   /** Asserts every cache, indexes 0 to {@code numMembersInCluster - 1}, reads {@code expectedVal} for {@code key}. */
   protected void assertExpectedValue(Object expectedVal, Object key) {
      for (int i = 0; i < numMembersInCluster; i++) {
         assertEquals(cache(i).get(key), expectedVal);
      }
   }

   /**
    * Unwraps any chain of {@link AbstractDelegatingTransport}s down to the underlying {@link JGroupsTransport}.
    *
    * @throws IllegalArgumentException if the innermost transport is not a JGroupsTransport (or is null)
    */
   private static JGroupsTransport extractJGroupsTransport(Transport transport) {
      if (transport instanceof AbstractDelegatingTransport) {
         return extractJGroupsTransport(((AbstractDelegatingTransport) transport).getDelegate());
      } else if (transport instanceof JGroupsTransport) {
         return (JGroupsTransport) transport;
      }
      throw new IllegalArgumentException("Transport is not a JGroupsTransport! It is "
            + (transport == null ? "null" : transport.getClass()));
   }
}