package org.infinispan.server.hotrod;
import static org.infinispan.server.hotrod.ResponseWriting.writeResponse;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.BitSet;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.Executor;
import javax.security.auth.Subject;
import org.infinispan.commons.logging.LogFactory;
import org.infinispan.commons.marshall.Marshaller;
import org.infinispan.commons.marshall.jboss.GenericJBossMarshaller;
import org.infinispan.security.Security;
import org.infinispan.server.core.transport.NettyTransport;
import org.infinispan.server.hotrod.iteration.IterableIterationResult;
import org.infinispan.server.hotrod.logging.Log;
import org.infinispan.server.hotrod.util.BulkUtil;
import org.infinispan.tasks.TaskContext;
import org.infinispan.tasks.TaskManager;
import org.infinispan.util.KeyValuePair;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
/**
* Handler that performs the actual cache operations for decoded Hot Rod requests. Note: this handler should run on a
* separate executor group from the decoder.
*
* @author wburns
* @since 9.0
*/
public class ContextHandler extends SimpleChannelInboundHandler<CacheDecodeContext> {
   private static final Log log = LogFactory.getLog(ContextHandler.class, Log.class);

   /** Server supplying the marshaller, query support, client listener registry and iteration manager. */
   private final HotRodServer server;
   /** Transport handed to the decoder when building STATS responses. */
   private final NettyTransport transport;
   /** Executor on which every decoded request is processed. */
   private final Executor executor;

   /**
    * @param server    the Hot Rod server this handler serves requests for
    * @param transport the transport whose statistics are reported for STATS requests
    * @param executor  the executor used to run each decoded request (see {@link #channelRead0})
    */
   public ContextHandler(HotRodServer server, NettyTransport transport, Executor executor) {
      this.server = server;
      this.transport = transport;
      this.executor = executor;
   }

   /**
    * Hands the decoded request to {@link #executor} and returns immediately. When the request carries a
    * {@link Subject}, the operation runs as that subject through {@link Security#doAs}. Any failure
    * (including {@link Error}s) is forwarded with {@link ChannelHandlerContext#fireExceptionCaught(Throwable)}
    * instead of escaping to the executor.
    */
   @Override
   protected void channelRead0(ChannelHandlerContext ctx, CacheDecodeContext msg) throws Exception {
      executor.execute(() -> {
         try {
            Subject subject = msg.subject;
            if (subject == null) {
               realRead(ctx, msg);
            } else {
               Security.doAs(subject, (PrivilegedExceptionAction<Void>) () -> {
                  realRead(ctx, msg);
                  return null;
               });
            }
         } catch (PrivilegedActionException e) {
            // Unwrap so the pipeline sees the exception the operation actually threw
            fireFailure(ctx, e.getCause());
         } catch (Throwable t) {
            // Throwable rather than Exception: an Error escaping here would only reach the executor thread,
            // and no response would be written for this request.
            fireFailure(ctx, t);
         }
      });
   }

   /**
    * Forwards a request failure down the pipeline. If the failure is an interruption (for example while
    * waiting on a task result), the thread's interrupt status is restored first so it is not lost.
    */
   private static void fireFailure(ChannelHandlerContext ctx, Throwable t) {
      if (t instanceof InterruptedException) {
         Thread.currentThread().interrupt();
      }
      ctx.fireExceptionCaught(t);
   }

   /**
    * Performs the operation named by the request header and writes its response to the channel.
    * ADD_CLIENT_LISTENER writes no response from this method; the channel and decoder are handed to the
    * client listener registry instead.
    *
    * @throws IllegalArgumentException if the operation is not handled here
    */
   protected void realRead(ChannelHandlerContext ctx, CacheDecodeContext msg) throws Exception {
      HotRodHeader h = msg.header;
      switch (h.op) {
         case PUT:
            writeResponse(msg, ctx.channel(), msg.put());
            break;
         case PUT_IF_ABSENT:
            writeResponse(msg, ctx.channel(), msg.putIfAbsent());
            break;
         case REPLACE:
            writeResponse(msg, ctx.channel(), msg.replace());
            break;
         case REPLACE_IF_UNMODIFIED:
            writeResponse(msg, ctx.channel(), msg.replaceIfUnmodified());
            break;
         case CONTAINS_KEY:
            writeResponse(msg, ctx.channel(), msg.containsKey());
            break;
         case GET:
         case GET_WITH_VERSION:
            writeResponse(msg, ctx.channel(), msg.get());
            break;
         case GET_STREAM:
         case GET_WITH_METADATA:
            writeResponse(msg, ctx.channel(), msg.getKeyMetadata());
            break;
         case REMOVE:
            writeResponse(msg, ctx.channel(), msg.remove());
            break;
         case REMOVE_IF_UNMODIFIED:
            writeResponse(msg, ctx.channel(), msg.removeIfUnmodified());
            break;
         case PING:
            writeResponse(msg, ctx.channel(), new EmptyResponse(h.version, h.messageId, h.cacheName,
                  h.clientIntel, HotRodOperation.PING, OperationStatus.Success, h.topologyId));
            break;
         case STATS:
            writeResponse(msg, ctx.channel(), msg.decoder.createStatsResponse(msg, transport));
            break;
         case CLEAR:
            writeResponse(msg, ctx.channel(), msg.clear());
            break;
         case SIZE:
            writeResponse(msg, ctx.channel(), new SizeResponse(h.version, h.messageId, h.cacheName,
                  h.clientIntel, h.topologyId, msg.cache.size()));
            break;
         case EXEC:
            executeTask(ctx, msg);
            break;
         case BULK_GET: {
            int size = (int) msg.operationDecodeContext;
            if (CacheDecodeContext.isTrace) {
               log.tracef("About to create bulk response count = %d", size);
            }
            writeResponse(msg, ctx.channel(), new BulkGetResponse(h.version, h.messageId, h.cacheName, h.clientIntel,
                  h.topologyId, size, msg.cache.entrySet()));
            break;
         }
         case BULK_GET_KEYS: {
            int scope = (int) msg.operationDecodeContext;
            if (CacheDecodeContext.isTrace) {
               log.tracef("About to create bulk get keys response scope = %d", scope);
            }
            writeResponse(msg, ctx.channel(), new BulkGetKeysResponse(h.version, h.messageId, h.cacheName, h.clientIntel,
                  h.topologyId, scope, BulkUtil.getAllKeys(msg.cache, scope)));
            break;
         }
         case QUERY: {
            byte[] queryResult = server.query(msg.cache, (byte[]) msg.operationDecodeContext);
            writeResponse(msg, ctx.channel(),
                  new QueryResponse(h.version, h.messageId, h.cacheName, h.clientIntel, h.topologyId, queryResult));
            break;
         }
         case ADD_CLIENT_LISTENER: {
            ClientListenerRequestContext clientContext = (ClientListenerRequestContext) msg.operationDecodeContext;
            server.getClientListenerRegistry().addClientListener(msg.decoder, ctx.channel(), h, clientContext.getListenerId(),
                  msg.cache, clientContext.isIncludeCurrentState(), new KeyValuePair<>(clientContext.getFilterFactoryInfo(),
                        clientContext.getConverterFactoryInfo()), clientContext.isUseRawData(), clientContext.getListenerInterests());
            break;
         }
         case REMOVE_CLIENT_LISTENER: {
            byte[] listenerId = (byte[]) msg.operationDecodeContext;
            if (server.getClientListenerRegistry().removeClientListener(listenerId, msg.cache)) {
               writeResponse(msg, ctx.channel(), msg.decoder.createSuccessResponse(h, null));
            } else {
               writeResponse(msg, ctx.channel(), msg.decoder.createNotExecutedResponse(h, null));
            }
            break;
         }
         case ITERATION_START:
            startIteration(ctx, msg);
            break;
         case ITERATION_NEXT: {
            String iterationId = (String) msg.operationDecodeContext;
            IterableIterationResult iterationResult = server.getIterationManager().next(msg.cache.getName(), iterationId);
            writeResponse(msg, ctx.channel(), new IterationNextResponse(h.version, h.messageId, h.cacheName,
                  h.clientIntel, h.topologyId, iterationResult));
            break;
         }
         case ITERATION_END: {
            String iterationId = (String) msg.operationDecodeContext;
            boolean removed = server.getIterationManager().close(msg.cache.getName(), iterationId);
            writeResponse(msg, ctx.channel(), new EmptyResponse(h.version, h.messageId, h.cacheName, h.clientIntel,
                  HotRodOperation.ITERATION_END,
                  removed ? OperationStatus.Success : OperationStatus.InvalidIteration, h.topologyId));
            break;
         }
         case PUT_ALL: {
            // The decoder stores the entries to write as the operation context for PUT_ALL
            @SuppressWarnings("unchecked")
            Map<byte[], byte[]> entries = (Map<byte[], byte[]>) msg.operationDecodeContext;
            msg.cache.putAll(entries, msg.buildMetadata());
            writeResponse(msg, ctx.channel(), msg.decoder.createSuccessResponse(h, null));
            break;
         }
         case GET_ALL: {
            // The decoder stores the requested keys as the operation context for GET_ALL
            @SuppressWarnings("unchecked")
            Set<byte[]> keys = (Set<byte[]>) msg.operationDecodeContext;
            Map<byte[], byte[]> map = msg.cache.getAll(keys);
            writeResponse(msg, ctx.channel(), new GetAllResponse(h.version, h.messageId, h.cacheName,
                  h.clientIntel, h.topologyId, map));
            break;
         }
         case PUT_STREAM:
            putStream(ctx, msg);
            break;
         default:
            throw new IllegalArgumentException("Unsupported operation invoked: " + h.op);
      }
   }

   /**
    * Handles EXEC: runs the named server task against the request's cache and writes its result.
    * A {@code null} task result is sent as an empty byte array. This blocks the calling executor
    * thread until the task completes.
    */
   private void executeTask(ChannelHandlerContext ctx, CacheDecodeContext msg) throws Exception {
      HotRodHeader h = msg.header;
      ExecRequestContext execContext = (ExecRequestContext) msg.operationDecodeContext;
      TaskManager taskManager = SecurityActions.getCacheGlobalComponentRegistry(msg.cache).getComponent(TaskManager.class);
      Marshaller marshaller = server.getMarshaller();
      if (marshaller == null) {
         // No marshaller configured on the server: fall back to a GenericJBossMarshaller
         marshaller = new GenericJBossMarshaller();
      }
      byte[] result = (byte[]) taskManager.runTask(execContext.getName(),
            new TaskContext().marshaller(marshaller).cache(msg.cache).parameters(execContext.getParams())).get();
      writeResponse(msg, ctx.channel(),
            new ExecResponse(h.version, h.messageId, h.cacheName, h.clientIntel, h.topologyId,
                  result == null ? new byte[]{} : result));
   }

   /**
    * Handles ITERATION_START: starts an iteration over the request's cache and replies with the new
    * iteration id. The optional raw option bytes are converted to a {@link BitSet} when present.
    */
   private void startIteration(ChannelHandlerContext ctx, CacheDecodeContext msg) throws Exception {
      HotRodHeader h = msg.header;
      IterationStartRequest iterationStart = (IterationStartRequest) msg.operationDecodeContext;
      Optional<BitSet> optionBitSet = iterationStart.getOptionBitSet().map(BitSet::valueOf);
      String iterationId = server.getIterationManager().start(msg.cache.getName(), optionBitSet,
            iterationStart.getFactory(), iterationStart.getBatch(), iterationStart.isMetadata());
      writeResponse(msg, ctx.channel(), new IterationStartResponse(h.version, h.messageId, h.cacheName,
            h.clientIntel, h.topologyId, iterationId));
   }

   /**
    * Handles PUT_STREAM: copies the streamed value out of its {@link ByteBuf}, then chooses the operation from
    * the stream version. A version of 0 does a plain put, a negative version does a putIfAbsent, and a
    * positive version does a versioned replace. The buffer is always released, even if the operation fails.
    */
   private void putStream(ChannelHandlerContext ctx, CacheDecodeContext msg) throws Exception {
      ByteBuf buf = (ByteBuf) msg.operationDecodeContext;
      try {
         byte[] bytes = new byte[buf.readableBytes()];
         buf.readBytes(bytes);
         // Swap the buffer for a heap copy before the cache call
         // (presumably put/putIfAbsent/replaceIfUnmodified read the value from here)
         msg.operationDecodeContext = bytes;
         long version = msg.params.streamVersion;
         if (version == 0) { // Normal put
            writeResponse(msg, ctx.channel(), msg.put());
         } else if (version < 0) { // putIfAbsent
            writeResponse(msg, ctx.channel(), msg.putIfAbsent());
         } else { // versioned replace
            writeResponse(msg, ctx.channel(), msg.replaceIfUnmodified());
         }
      } finally {
         buf.release();
      }
   }

   /** Logs activation and asks the client listener registry to find and write events for this channel. */
   @Override
   public void channelActive(ChannelHandlerContext ctx) throws Exception {
      super.channelActive(ctx);
      log.tracef("Channel %s became active", ctx.channel());
      server.getClientListenerRegistry().findAndWriteEvents(ctx.channel());
   }

   /** Logs the change and asks the client listener registry to find and write events for this channel. */
   @Override
   public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception {
      super.channelWritabilityChanged(ctx);
      log.tracef("Channel %s writability changed", ctx.channel());
      server.getClientListenerRegistry().findAndWriteEvents(ctx.channel());
   }

   /**
    * Accepts only messages whose runtime class is exactly {@link CacheDecodeContext}. Subclasses are rejected.
    * An exact class comparison is cheaper than Netty's default type matcher.
    */
   @Override
   public boolean acceptInboundMessage(Object msg) throws Exception {
      return msg.getClass() == CacheDecodeContext.class;
   }
}