package org.infinispan.commands.remote;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import org.infinispan.commands.CommandsFactory;
import org.infinispan.commands.control.LockControlCommand;
import org.infinispan.commands.read.GetAllCommand;
import org.infinispan.commons.marshall.MarshallUtil;
import org.infinispan.commons.util.EnumUtil;
import org.infinispan.container.InternalEntryFactory;
import org.infinispan.container.entries.CacheEntry;
import org.infinispan.container.entries.InternalCacheEntry;
import org.infinispan.container.entries.InternalCacheValue;
import org.infinispan.context.InvocationContext;
import org.infinispan.context.InvocationContextFactory;
import org.infinispan.context.impl.FlagBitSets;
import org.infinispan.interceptors.AsyncInterceptorChain;
import org.infinispan.remoting.responses.Response;
import org.infinispan.transaction.impl.TransactionTable;
import org.infinispan.transaction.xa.GlobalTransaction;
import org.infinispan.util.ByteString;
import org.infinispan.util.logging.Log;
import org.infinispan.util.logging.LogFactory;
/**
* Issues a remote getAll call. This is not a {@link org.infinispan.commands.VisitableCommand} and hence not passed up the
* interceptor chain.
*
* @author Radim Vansa <rvansa@redhat.com>
*/
public class ClusteredGetAllCommand<K, V> extends BaseClusteredReadCommand {
public static final byte COMMAND_ID = 46;
private static final Log log = LogFactory.getLog(ClusteredGetAllCommand.class);
private static final boolean trace = log.isTraceEnabled();
private List<?> keys;
private GlobalTransaction gtx;
private InvocationContextFactory icf;
private CommandsFactory commandsFactory;
private AsyncInterceptorChain invoker;
private TransactionTable txTable;
private InternalEntryFactory entryFactory;
ClusteredGetAllCommand() {
super(null, EnumUtil.EMPTY_BIT_SET);
}
public ClusteredGetAllCommand(ByteString cacheName) {
super(cacheName, EnumUtil.EMPTY_BIT_SET);
}
public ClusteredGetAllCommand(ByteString cacheName, List<?> keys, long flags, GlobalTransaction gtx) {
super(cacheName, flags);
this.keys = keys;
this.gtx = gtx;
}
public void init(InvocationContextFactory icf, CommandsFactory commandsFactory, InternalEntryFactory entryFactory,
AsyncInterceptorChain interceptorChain, TransactionTable txTable) {
this.icf = icf;
this.commandsFactory = commandsFactory;
this.invoker = interceptorChain;
this.txTable = txTable;
this.entryFactory = entryFactory;
}
@Override
public CompletableFuture<Object> invokeAsync() throws Throwable {
if (!hasAnyFlag(FlagBitSets.FORCE_WRITE_LOCK)) {
return invokeGetAll();
} else {
return acquireLocks().thenCompose(o -> invokeGetAll());
}
}
private CompletableFuture<Object> invokeGetAll() {
// make sure the get command doesn't perform a remote call
// as our caller is already calling the ClusteredGetCommand on all the relevant nodes
GetAllCommand command = commandsFactory.buildGetAllCommand(keys, getFlagsBitSet(), true);
InvocationContext invocationContext = icf.createRemoteInvocationContextForCommand(command, getOrigin());
CompletableFuture<Object> future = invoker.invokeAsync(invocationContext, command);
return future.thenApply(rv -> {
if (trace) log.trace("Found: " + rv);
if (rv == null || rv instanceof Response) {
return rv;
}
Map<K, CacheEntry<K, V>> map = (Map<K, CacheEntry<K, V>>) rv;
InternalCacheValue<V>[] values = new InternalCacheValue[keys.size()];
int i = 0;
for (Object key : keys) {
CacheEntry<K, V> entry = map.get(key);
InternalCacheValue<V> value;
if (entry == null) {
value = null;
} else if (entry instanceof InternalCacheEntry) {
value = ((InternalCacheEntry<K, V>) entry).toInternalCacheValue();
} else {
value = entryFactory.createValue(entry);
}
values[i++] = value;
}
return values;
});
}
private CompletableFuture<Object> acquireLocks() throws Throwable {
LockControlCommand lockControlCommand = commandsFactory.buildLockControlCommand(keys, getFlagsBitSet(), gtx);
lockControlCommand.init(invoker, icf, txTable);
return lockControlCommand.invokeAsync();
}
public List<?> getKeys() {
return keys;
}
@Override
public byte getCommandId() {
return COMMAND_ID;
}
@Override
public void writeTo(ObjectOutput output) throws IOException {
MarshallUtil.marshallCollection(keys, output);
output.writeLong(FlagBitSets.copyWithoutRemotableFlags(getFlagsBitSet()));
output.writeObject(gtx);
}
@Override
public void readFrom(ObjectInput input) throws IOException, ClassNotFoundException {
keys = MarshallUtil.unmarshallCollection(input, ArrayList::new);
setFlagsBitSet(input.readLong());
gtx = (GlobalTransaction) input.readObject();
}
@Override
public boolean isReturnValueExpected() {
return true;
}
@Override
public String toString() {
final StringBuilder sb = new StringBuilder("ClusteredGetAllCommand{");
sb.append("keys=").append(keys);
sb.append(", flags=").append(printFlags());
sb.append('}');
return sb.toString();
}
@Override
public boolean equals(Object obj) {
if (this == obj)
return true;
if (obj == null)
return false;
if (getClass() != obj.getClass())
return false;
ClusteredGetAllCommand<?, ?> other = (ClusteredGetAllCommand<?, ?>) obj;
if (gtx == null) {
if (other.gtx != null)
return false;
} else if (!gtx.equals(other.gtx))
return false;
if (keys == null) {
if (other.keys != null)
return false;
} else if (!keys.equals(other.keys))
return false;
return true;
}
@Override
public int hashCode() {
final int prime = 31;
int result = 1;
result = prime * result + ((gtx == null) ? 0 : gtx.hashCode());
result = prime * result + ((keys == null) ? 0 : keys.hashCode());
return result;
}
}