package org.radargun.service; import java.util.List; import java.util.Set; import org.infinispan.remoting.transport.Transport; import org.infinispan.remoting.transport.jgroups.JGroupsTransport; import org.jgroups.JChannel; import org.jgroups.protocols.TP; import org.jgroups.stack.ProtocolStack; import org.radargun.protocols.SLAVE_PARTITION; import org.radargun.traits.Partitionable; public class InfinispanPartitionableLifecycle extends InfinispanKillableLifecycle implements Partitionable { private int mySlaveIndex = -1; private Set<Integer> initiallyReachable; public InfinispanPartitionableLifecycle(Infinispan51EmbeddedService wrapper) { super(wrapper); } protected Class<? extends SLAVE_PARTITION> getPartitionProtocolClass() { return SLAVE_PARTITION.class; } @Override public void setMembersInPartition(int slaveIndex, Set<Integer> members) { List<JChannel> channels = getChannels(null); log.trace("Found " + channels.size() + " channels"); for (JChannel channel : channels) { setPartitionInChannel(channel, slaveIndex, members); } } private void setPartitionInChannel(JChannel channel, int slaveIndex, Set<Integer> members) { log.trace("Setting partition in channel " + channel); SLAVE_PARTITION partition = (SLAVE_PARTITION) channel.getProtocolStack().findProtocol(getPartitionProtocolClass()); if (partition == null) { log.info("No SLAVE_PARTITION protocol found in stack for " + channel.getName() + ", inserting above transport protocol"); try { partition = getPartitionProtocolClass().newInstance(); } catch (Exception e) { log.error("Error creating SLAVE_PARTITION protocol", e); return; } try { channel.getProtocolStack().insertProtocol(partition, ProtocolStack.ABOVE, TP.class); } catch (Exception e) { log.error("Error inserting the SLAVE_PARTITION protocol to stack for " + channel.getName()); return; } } partition.setSlaveIndex(slaveIndex); partition.setAllowedSlaves(members); log.trace("Finished setting partition in channel " + channel); } @Override public void setStartWithReachable(int slaveIndex, Set<Integer> members) { mySlaveIndex = slaveIndex; initiallyReachable = members; } public Transport createTransport() { return new HookedJGroupsTransport(); } private class HookedJGroupsTransport extends JGroupsTransport { /** * This is called after the channel is initialized but before it is connected */ @Override protected void startJGroupsChannelIfNeeded() { log.trace("My index is " + mySlaveIndex + " and these slaves should be reachable: " + initiallyReachable); if (mySlaveIndex >= 0 && initiallyReachable != null) { List<JChannel> channels = getChannels((JChannel) this.channel); log.trace("Found " + channels.size() + " channels"); for (JChannel channel : channels) { setPartitionInChannel(channel, mySlaveIndex, initiallyReachable); } } super.startJGroupsChannelIfNeeded(); } } }