package org.infinispan.server.hotrod.test;
import static org.infinispan.server.hotrod.OperationStatus.NotExecutedWithPrevious;
import static org.infinispan.server.hotrod.OperationStatus.Success;
import static org.infinispan.server.hotrod.OperationStatus.SuccessWithPrevious;
import static org.infinispan.server.hotrod.transport.ExtendedByteBuf.readString;
import static org.infinispan.server.hotrod.transport.ExtendedByteBuf.readUnsignedInt;
import static org.infinispan.server.hotrod.transport.ExtendedByteBuf.readUnsignedShort;
import static org.infinispan.server.hotrod.transport.ExtendedByteBuf.writeRangedBytes;
import static org.infinispan.server.hotrod.transport.ExtendedByteBuf.writeString;
import static org.infinispan.server.hotrod.transport.ExtendedByteBuf.writeUnsignedInt;
import static org.infinispan.server.hotrod.transport.ExtendedByteBuf.writeUnsignedLong;
import static org.testng.AssertJUnit.assertFalse;
import static org.testng.AssertJUnit.assertTrue;

import java.lang.reflect.Method;
import java.net.InetSocketAddress;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;

import javax.net.ssl.SSLEngine;
import javax.security.sasl.SaslClient;
import javax.security.sasl.SaslException;

import org.infinispan.commons.logging.LogFactory;
import org.infinispan.commons.marshall.WrappedByteArray;
import org.infinispan.commons.util.Util;
import org.infinispan.server.core.transport.NettyInitializer;
import org.infinispan.server.core.transport.NettyInitializers;
import org.infinispan.server.hotrod.Constants;
import org.infinispan.server.hotrod.HotRodOperation;
import org.infinispan.server.hotrod.OperationStatus;
import org.infinispan.server.hotrod.ProtocolFlag;
import org.infinispan.server.hotrod.Response;
import org.infinispan.server.hotrod.ServerAddress;
import org.infinispan.server.hotrod.logging.Log;
import org.infinispan.server.hotrod.transport.ExtendedByteBuf;
import org.infinispan.test.TestingUtil;
import org.infinispan.test.fwk.TestResourceTracker;
import org.infinispan.util.KeyValuePair;

import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.MessageToByteEncoder;
import io.netty.handler.codec.ReplayingDecoder;
import io.netty.handler.ssl.SslHandler;
import io.netty.util.concurrent.DefaultThreadFactory;
import io.netty.util.concurrent.Future;
/**
* A very simple Hot Rod client for testing purposes. It is a quick and dirty client implementation and, as a result,
* it might not be very readable. The code was originally written in Scala, which explains some of its structure.
* <p>
* Historical note from the Scala version: the client was deliberately not written as a trait, because storing var
* instances in a trait caused issues with TestNG, see:
* http://thread.gmane.org/gmane.comp.lang.scala.user/24317
*
* @author Galder Zamarreño
* @author Tristan Tarrant
* @since 4.1
*/
public class HotRodClient {
   private static final Log log = LogFactory.getLog(HotRodClient.class, Log.class);
   final static AtomicLong idCounter = new AtomicLong();
   final String host;
   final int port;
   final String defaultCacheName;
   final int rspTimeoutSeconds;
   final byte protocolVersion;
   final SSLEngine sslEngine;
   final Channel ch;
   // Every request sent through writeOp, keyed by message id. Entries are never removed by this class.
   Map<Long, Op> idToOp = new ConcurrentHashMap<>();
   private final EventLoopGroup eventLoopGroup =
         new NioEventLoopGroup(1, new DefaultThreadFactory(TestResourceTracker.getCurrentTestShortName() + "-Client"));

   /**
    * Creates a plain-text client and connects it to {@code host:port}.
    *
    * @param host              server host
    * @param port              server port
    * @param defaultCacheName  cache used by the convenience operations of this client
    * @param rspTimeoutSeconds response timeout, in seconds, handed to the {@link ClientHandler}
    * @param protocolVersion   Hot Rod protocol version used for every request
    */
   public HotRodClient(String host, int port, String defaultCacheName, int rspTimeoutSeconds, byte protocolVersion) {
      this(host, port, defaultCacheName, rspTimeoutSeconds, protocolVersion, null);
   }

   /**
    * Creates a client and connects it to {@code host:port}, optionally over TLS.
    *
    * @param host              server host
    * @param port              server port
    * @param defaultCacheName  cache used by the convenience operations of this client
    * @param rspTimeoutSeconds response timeout, in seconds, handed to the {@link ClientHandler}
    * @param protocolVersion   Hot Rod protocol version used for every request
    * @param sslEngine         engine wrapped in an {@link SslHandler}, or {@code null} for a plain-text connection
    */
   public HotRodClient(String host, int port, String defaultCacheName, int rspTimeoutSeconds, byte protocolVersion,
                       SSLEngine sslEngine) {
      this.host = host;
      this.port = port;
      this.defaultCacheName = defaultCacheName;
      this.rspTimeoutSeconds = rspTimeoutSeconds;
      this.protocolVersion = protocolVersion;
      this.sslEngine = sslEngine;
      ch = initializeChannel();
   }

   /**
    * @return the cache name used by the convenience operations of this client
    */
   public String defaultCacheName() {
      return defaultCacheName;
   }

   // Opens the connection and blocks (uninterruptibly) until it has been established.
   private Channel initializeChannel() {
      Bootstrap bootstrap = new Bootstrap()
            .group(eventLoopGroup)
            .handler(new NettyInitializers(new ClientChannelInitializer(this, rspTimeoutSeconds, sslEngine, protocolVersion)))
            .channel(NioSocketChannel.class)
            .option(ChannelOption.TCP_NODELAY, true)
            .option(ChannelOption.SO_KEEPALIVE, true);
      ChannelFuture connectFuture = bootstrap.connect(new InetSocketAddress(host, port));
      Channel channel = connectFuture.syncUninterruptibly().channel();
      assertTrue(connectFuture.isSuccess());
      return channel;
   }

   /**
    * Gracefully shuts down the client's event loop (100 ms quiet period, 1000 ms timeout).
    *
    * @return a future completed once the event loop has terminated
    */
   public Future<?> stop() {
      return eventLoopGroup.shutdownGracefully(100, 1000, TimeUnit.MILLISECONDS);
   }

   /** Stores {@code k -> v} (opcode 0x01) in the default cache with client intelligence 1 and topology id 0. */
   public TestResponse put(byte[] k, int lifespan, int maxIdle, byte[] v) {
      return execute(0xA0, (byte) 0x01, defaultCacheName, k, lifespan, maxIdle, v, 0, (byte) 1, 0);
   }

   /** Stores {@code k -> v} (opcode 0x01) with an explicit client intelligence and topology id. */
   public TestResponse put(byte[] k, int lifespan, int maxIdle, byte[] v, byte clientIntelligence, int topologyId) {
      return execute(0xA0, (byte) 0x01, defaultCacheName, k, lifespan, maxIdle, v, 0, clientIntelligence, topologyId);
   }

   // Asserts the response status; the failure message includes the server's error text when there is one.
   private boolean assertStatus(TestResponse resp, OperationStatus expected) {
      OperationStatus status = resp.getStatus();
      boolean isSuccess = status == expected;
      if (resp instanceof TestErrorResponse) {
         assertTrue(String.format("Status should have been '%s' but instead was: '%s', and the error message was: %s",
               expected, status, ((TestErrorResponse) resp).msg), isSuccess);
      } else {
         assertTrue(String.format(
               "Status should have been '%s' but instead was: '%s'", expected, status), isSuccess);
      }
      return isSuccess;
   }

   // Key derived from the test method name: "k-" + method name.
   private byte[] k(Method m) {
      return k(m, "k-");
   }

   private byte[] k(Method m, String prefix) {
      byte[] bytes = (prefix + m.getName()).getBytes();
      log.tracef("String %s is converted to %s bytes", prefix + m.getName(), Util.printArray(bytes, true));
      return bytes;
   }

   // Value derived from the test method name: "v-" + method name.
   private byte[] v(Method m) {
      return v(m, "v-");
   }

   private byte[] v(Method m, String prefix) {
      return k(m, prefix);
   }

   /** Puts the key/value pair derived from {@code m} and asserts the status is {@code Success}. */
   public void assertPut(Method m) {
      assertStatus(put(k(m), 0, 0, v(m)), Success);
   }

   /**
    * Sends a put for the key/value derived from {@code m} and asserts that writing it to the channel fails.
    * No response is awaited.
    */
   public void assertPutFail(Method m) {
      // flags 0, data version 1, client intelligence 0, topology id 0
      Op op = new Op(0xA0, protocolVersion, (byte) 0x01, defaultCacheName, k(m), 0, 0, v(m), 0, 1, (byte) 0, 0);
      assertFalse(writeOp(op, false));
   }

   /** Puts the key/value pair derived from {@code m} with custom prefixes and asserts {@code Success}. */
   public void assertPut(Method m, String kPrefix, String vPrefix) {
      assertStatus(put(k(m, kPrefix), 0, 0, v(m, vPrefix)), Success);
   }

   /** Puts the key/value pair derived from {@code m} with expiration settings and asserts {@code Success}. */
   public void assertPut(Method m, int lifespan, int maxIdle) {
      assertStatus(put(k(m), lifespan, maxIdle, v(m)), Success);
   }

   /** Stores {@code k -> v} using the strings' default-charset bytes, without expiration. */
   public TestResponse put(String k, String v) {
      return put(k.getBytes(), 0, 0, v.getBytes());
   }

   /** Stores {@code k -> v} (opcode 0x01) with the given request flags. */
   public TestResponse put(byte[] k, int lifespan, int maxIdle, byte[] v, int flags) {
      return execute(0xA0, (byte) 0x01, defaultCacheName, k, lifespan, maxIdle, v, 0, flags);
   }

   /** Sends putIfAbsent (opcode 0x05). */
   public TestResponse putIfAbsent(byte[] k, int lifespan, int maxIdle, byte[] v) {
      return execute(0xA0, (byte) 0x05, defaultCacheName, k, lifespan, maxIdle, v, 0, (byte) 1, 0);
   }

   /** Sends putIfAbsent (opcode 0x05) with the given request flags. */
   public TestResponse putIfAbsent(byte[] k, int lifespan, int maxIdle, byte[] v, int flags) {
      return execute(0xA0, (byte) 0x05, defaultCacheName, k, lifespan, maxIdle, v, 0, flags);
   }

   /** Sends replace (opcode 0x07). */
   public TestResponse replace(byte[] k, int lifespan, int maxIdle, byte[] v) {
      return execute(0xA0, (byte) 0x07, defaultCacheName, k, lifespan, maxIdle, v, 0, (byte) 1, 0);
   }

   /** Sends replace (opcode 0x07) with the given request flags. */
   public TestResponse replace(byte[] k, int lifespan, int maxIdle, byte[] v, int flags) {
      return execute(0xA0, (byte) 0x07, defaultCacheName, k, lifespan, maxIdle, v, 0, flags);
   }

   /** Sends replaceIfUnmodified (opcode 0x09), conditional on {@code dataVersion}. */
   public TestResponse replaceIfUnmodified(byte[] k, int lifespan, int maxIdle, byte[] v, long dataVersion) {
      return execute(0xA0, (byte) 0x09, defaultCacheName, k, lifespan, maxIdle, v, dataVersion, (byte) 1, 0);
   }

   /** Sends replaceIfUnmodified (opcode 0x09) with the given request flags. */
   public TestResponse replaceIfUnmodified(byte[] k, int lifespan, int maxIdle, byte[] v, long dataVersion, int flags) {
      return execute(0xA0, (byte) 0x09, defaultCacheName, k, lifespan, maxIdle, v, dataVersion, flags);
   }

   /** Sends remove (opcode 0x0B). */
   public TestResponse remove(byte[] k) {
      return execute(0xA0, (byte) 0x0B, defaultCacheName, k, 0, 0, null, 0, (byte) 1, 0);
   }

   /** Sends remove (opcode 0x0B) with the given request flags. */
   public TestResponse remove(byte[] k, int flags) {
      return execute(0xA0, (byte) 0x0B, defaultCacheName, k, 0, 0, null, 0, flags);
   }

   /** Sends removeIfUnmodified (opcode 0x0D), conditional on {@code dataVersion}. */
   public TestResponse removeIfUnmodified(byte[] k, int lifespan, int maxIdle, byte[] v, long dataVersion) {
      return execute(0xA0, (byte) 0x0D, defaultCacheName, k, lifespan, maxIdle, v, dataVersion, (byte) 1, 0);
   }

   /** Sends removeIfUnmodified (opcode 0x0D) with the given request flags. */
   public TestResponse removeIfUnmodified(byte[] k, long dataVersion, int flags) {
      return execute(0xA0, (byte) 0x0D, defaultCacheName, k, 0, 0, new byte[0], dataVersion, flags);
   }

   /**
    * Sends an arbitrary request with no flags and waits for the response carrying the request's message id.
    */
   public TestResponse execute(int magic, byte code, String name, byte[] k, int lifespan, int maxIdle,
                               byte[] v, long dataVersion, byte clientIntelligence, int topologyId) {
      Op op = new Op(magic, protocolVersion, code, name, k, lifespan, maxIdle, v, 0, dataVersion,
            clientIntelligence, topologyId);
      return execute(op, op.id);
   }

   /**
    * Sends a request with the given (presumably invalid) magic byte and waits for the error response.
    */
   public TestErrorResponse executeExpectBadMagic(int magic, byte code, String name, byte[] k, int lifespan, int maxIdle,
                                                  byte[] v, long version) {
      Op op = new Op(magic, protocolVersion, code, name, k, lifespan, maxIdle, v, 0, version, (byte) 1, 0);
      // Waits on message id 0 rather than op.id, presumably because the server cannot attribute a bad-magic
      // request to its message id.
      return (TestErrorResponse) execute(op, 0);
   }

   /** Sends a truncated ({@link PartialOp}) request and waits for the resulting error response. */
   public TestErrorResponse executePartial(int magic, byte code, String name, byte[] k, int lifespan, int maxIdle,
                                           byte[] v, long version) {
      Op op = new PartialOp(magic, protocolVersion, code, name, k, lifespan, maxIdle, v, 0, version, (byte) 1, 0);
      return (TestErrorResponse) execute(op, op.id);
   }

   /**
    * Sends an arbitrary request with the given flags and waits for the response carrying the request's message id.
    */
   public TestResponse execute(int magic, byte code, String name, byte[] k, int lifespan, int maxIdle,
                               byte[] v, long dataVersion, int flags) {
      Op op = new Op(magic, protocolVersion, code, name, k, lifespan, maxIdle, v, flags, dataVersion, (byte) 1, 0);
      return execute(op, op.id);
   }

   // Writes op (asserting the write succeeded) and blocks until the response with the given message id arrives.
   private TestResponse execute(Op op, long expectedResponseMessageId) {
      writeOp(op);
      return handler().getResponse(expectedResponseMessageId);
   }

   // The response handler is the last handler installed by ClientChannelInitializer.
   private ClientHandler handler() {
      return (ClientHandler) ch.pipeline().last();
   }

   /**
    * Registers and writes {@code op}, waits for the write to complete and asserts that it succeeded.
    *
    * @return whether the write succeeded (always {@code true} if this method returns normally)
    */
   public boolean writeOp(Op op) {
      return writeOp(op, true);
   }

   /**
    * Registers {@code op} in {@link #idToOp}, writes it and waits (uninterruptibly) for the write to complete.
    *
    * @param assertSuccess whether to assert that the write succeeded
    * @return whether the write succeeded
    */
   public boolean writeOp(Op op, boolean assertSuccess) {
      idToOp.put(op.id, op);
      ChannelFuture future = ch.writeAndFlush(op);
      future.awaitUninterruptibly();
      if (assertSuccess)
         assertTrue(future.isSuccess());
      return future.isSuccess();
   }

   /** Sends get (opcode 0x03) with the given flags. */
   public TestGetResponse get(byte[] k, int flags) {
      return (TestGetResponse) get((byte) 0x03, k, flags);
   }

   /** Sends get (opcode 0x03) for the default-charset bytes of {@code k}. */
   public TestResponse get(String k) {
      return get((byte) 0x03, k.getBytes(), 0);
   }

   /** Fetches the key derived from {@code m}; despite the name, no assertion is made on the result. */
   public TestGetResponse assertGet(Method m) {
      return assertGet(m, 0);
   }

   /** Fetches the key derived from {@code m} with the given flags; no assertion is made on the result. */
   public TestGetResponse assertGet(Method m, int flags) {
      return get(k(m), flags);
   }

   /** Sends containsKey (opcode 0x0F). */
   public TestResponse containsKey(byte[] k, int flags) {
      return get((byte) 0x0F, k, flags);
   }

   /** Sends getWithVersion (opcode 0x11). */
   public TestGetWithVersionResponse getWithVersion(byte[] k, int flags) {
      return (TestGetWithVersionResponse) get((byte) 0x11, k, flags);
   }

   /** Sends getWithMetadata (opcode 0x1B). */
   public TestGetWithMetadataResponse getWithMetadata(byte[] k, int flags) {
      return (TestGetWithMetadataResponse) get((byte) 0x1B, k, flags);
   }

   // Sends a key-only read request. A response is awaited only for the read opcodes listed below; any other
   // opcode is fire-and-forget and yields null.
   private TestResponse get(byte code, byte[] k, int flags) {
      Op op = new Op(0xA0, protocolVersion, code, defaultCacheName, k, 0, 0, null, flags, 0, (byte) 1, 0);
      writeOp(op);
      if (code == 0x03 || code == 0x11 || code == 0x0F || code == 0x1B) {
         return handler().getResponse(op.id);
      } else {
         return null;
      }
   }

   /** Sends clear (opcode 0x13). */
   public TestResponse clear() {
      return execute(0xA0, (byte) 0x13, defaultCacheName, null, 0, 0, null, 0, (byte) 1, 0);
   }

   /**
    * Sends stats (opcode 0x15).
    *
    * @return the statistics name/value pairs reported by the server
    */
   public Map<String, String> stats() {
      StatsOp op = new StatsOp(0xA0, protocolVersion, (byte) 0x15, defaultCacheName, (byte) 1, 0, null);
      TestStatsResponse resp = (TestStatsResponse) execute(op, op.id);
      return resp.stats;
   }

   /** Sends ping (opcode 0x17). */
   public TestResponse ping() {
      return execute(0xA0, (byte) 0x17, defaultCacheName, null, 0, 0, null, 0, (byte) 1, 0);
   }

   /** Sends ping (opcode 0x17) with an explicit client intelligence and topology id. */
   public TestResponse ping(byte clientIntelligence, int topologyId) {
      return execute(0xA0, (byte) 0x17, defaultCacheName, null, 0, 0, null, 0, clientIntelligence, topologyId);
   }

   /** Sends bulkGet (opcode 0x19) with a count of 0. */
   public TestBulkGetResponse bulkGet() {
      return bulkGet(0);
   }

   /** Sends bulkGet (opcode 0x19) with the given entry count. */
   public TestBulkGetResponse bulkGet(int count) {
      BulkGetOp op = new BulkGetOp(0xA0, protocolVersion, (byte) 0x19, defaultCacheName, (byte) 1, 0, count);
      return (TestBulkGetResponse) execute(op, op.id);
   }

   /** Sends bulkGetKeys (opcode 0x1D) with scope 0. */
   public TestBulkGetKeysResponse bulkGetKeys() {
      return bulkGetKeys(0);
   }

   /** Sends bulkGetKeys (opcode 0x1D) with the given scope. */
   public TestBulkGetKeysResponse bulkGetKeys(int scope) {
      BulkGetKeysOp op = new BulkGetKeysOp(0xA0, protocolVersion, (byte) 0x1D, defaultCacheName, (byte) 1, 0, scope);
      return (TestBulkGetKeysResponse) execute(op, op.id);
   }

   /** Sends a query request carrying the given serialized query bytes. */
   public TestQueryResponse query(byte[] query) {
      QueryOp op = new QueryOp(0xA0, protocolVersion, defaultCacheName, (byte) 1, 0, query);
      return (TestQueryResponse) execute(op, op.id);
   }

   /** Sends authMechList (opcode 0x21). */
   public TestAuthMechListResponse authMechList() {
      AuthMechListOp op = new AuthMechListOp(0xA0, protocolVersion, (byte) 0x21, defaultCacheName, (byte) 1, 0);
      return (TestAuthMechListResponse) execute(op, op.id);
   }

   /**
    * Runs a SASL exchange (opcode 0x23) until both the SASL client and the server report completion.
    * <p>
    * The SASL client is always disposed, including when the exchange fails with an exception.
    *
    * @param sc SASL client driving the exchange
    * @return the final auth response from the server
    * @throws SaslException if the SASL client fails to evaluate a challenge
    */
   public TestAuthResponse auth(SaslClient sc) throws SaslException {
      SaslClient saslClient = sc;
      try {
         byte[] saslResponse = saslClient.hasInitialResponse() ? saslClient.evaluateChallenge(new byte[0]) : new byte[0];
         ClientHandler handler = handler();
         // Only the first message names the mechanism; follow-ups send an empty mechanism name
         AuthOp op = new AuthOp(0xA0, protocolVersion, (byte) 0x23, defaultCacheName, (byte) 1, 0, saslClient.getMechanismName(), saslResponse);
         writeOp(op);
         TestAuthResponse response = (TestAuthResponse) handler.getResponse(op.id);
         while (!saslClient.isComplete() || !response.complete) {
            saslResponse = saslClient.evaluateChallenge(response.challenge);
            op = new AuthOp(0xA0, protocolVersion, (byte) 0x23, defaultCacheName, (byte) 1, 0, "", saslResponse);
            writeOp(op);
            response = (TestAuthResponse) handler.getResponse(op.id);
         }
         return response;
      } finally {
         saslClient.dispose();
      }
   }

   /**
    * Registers {@code listener} locally, then asks the server to add it.
    *
    * @return the server's response to the add-listener request
    */
   public TestResponse addClientListener(TestClientListener listener, boolean includeState,
                                         Optional<KeyValuePair<String, List<byte[]>>> filterFactory,
                                         Optional<KeyValuePair<String, List<byte[]>>> converterFactory, boolean useRawData) {
      AddClientListenerOp op = new AddClientListenerOp(0xA0, protocolVersion, defaultCacheName,
            (byte) 1, 0, listener.getId(), includeState, filterFactory, converterFactory, useRawData);
      ClientHandler handler = handler();
      // Registered before the request is written, presumably so events sent right after it are not missed
      handler.addClientListener(listener);
      writeOp(op);
      return handler.getResponse(op.id);
   }

   /**
    * Asks the server to remove a listener. The local registration is dropped only if the server reports
    * {@code Success}.
    */
   public TestResponse removeClientListener(byte[] listenerId) {
      RemoveClientListenerOp op = new RemoveClientListenerOp(0xA0, protocolVersion, defaultCacheName, (byte) 1, 0, listenerId);
      ClientHandler handler = handler();
      writeOp(op);
      TestResponse response = handler.getResponse(op.id);
      if (response.getStatus() == Success) handler.removeClientListener(listenerId);
      return response;
   }

   /** Sends a size request for the default cache. */
   public TestSizeResponse size() {
      SizeOp op = new SizeOp(0xA0, protocolVersion, defaultCacheName, (byte) 1, 0);
      return (TestSizeResponse) execute(op, op.id);
   }

   /** Sends a getStream request for {@code key}, starting at {@code offset}. */
   public TestGetWithMetadataResponse getStream(byte[] key, int offset) {
      GetStreamOp op = new GetStreamOp(0xA0, protocolVersion, defaultCacheName, key, 0, (byte) 1, 0, offset);
      return (TestGetWithMetadataResponse) execute(op, op.id);
   }

   /** Sends a putStream request; the encoder writes the value in chunks. */
   public TestResponse putStream(byte[] key, byte[] value, long version, int lifespan, int maxIdle) {
      PutStreamOp op = new PutStreamOp(0xA0, protocolVersion, defaultCacheName, key, value, lifespan, maxIdle, version, (byte) 1, 0);
      return execute(op, op.id);
   }
}
/**
 * Builds the client-side pipeline: optional TLS, then the response decoder, the request encoder and finally the
 * handler that hands responses back to {@link HotRodClient}.
 */
class ClientChannelInitializer implements NettyInitializer {
   private final HotRodClient owner;
   private final int responseTimeoutSeconds;
   private final SSLEngine engine;
   private final byte version;

   ClientChannelInitializer(HotRodClient client, int rspTimeoutSeconds, SSLEngine sslEngine, byte protocolVersion) {
      owner = client;
      responseTimeoutSeconds = rspTimeoutSeconds;
      engine = sslEngine;
      version = protocolVersion;
   }

   @Override
   public void initializeChannel(Channel ch) throws Exception {
      ChannelPipeline p = ch.pipeline();
      // TLS must sit first so every other handler sees plain-text bytes
      if (engine != null) {
         p.addLast("ssl", new SslHandler(engine));
      }
      p.addLast("decoder", new Decoder(owner))
            .addLast("encoder", new Encoder(version))
            .addLast("handler", new ClientHandler(responseTimeoutSeconds));
   }
}
/**
 * Serializes test {@link Op} instances into Hot Rod request frames.
 * <p>
 * Cache names and SASL mechanism names are encoded as UTF-8, which the Hot Rod protocol uses for strings. Using the
 * JVM default charset would make non-ASCII names encode differently on different platforms.
 */
class Encoder extends MessageToByteEncoder<Object> {
   // Maximum payload of a single chunk when a value is written in chunks (opcode 0x39)
   private static final int STREAM_CHUNK_SIZE = 8192;
   private final byte protocolVersion;
   private static final Log log = LogFactory.getLog(Encoder.class, Log.class);

   Encoder(byte protocolVersion) {
      this.protocolVersion = protocolVersion;
   }

   @Override
   protected void encode(ChannelHandlerContext ctx, Object msg, ByteBuf buffer) throws Exception {
      log.tracef("Encode %s so that it's sent to the server", msg);
      // Subtypes are tested before the general Op case
      if (msg instanceof PartialOp) {
         writePartialHeader((PartialOp) msg, buffer);
      } else if (msg instanceof AddClientListenerOp) {
         writeAddClientListener((AddClientListenerOp) msg, buffer);
      } else if (msg instanceof RemoveClientListenerOp) {
         RemoveClientListenerOp op = (RemoveClientListenerOp) msg;
         writeHeader(op, buffer);
         writeRangedBytes(op.listenerId, buffer);
      } else if (msg instanceof Op) {
         writeOperation((Op) msg, buffer);
      }
   }

   // Only the first four header fields are written for a PartialOp, producing a truncated request.
   private static void writePartialHeader(PartialOp partial, ByteBuf buffer) {
      buffer.writeByte((byte) partial.magic); // magic
      writeUnsignedLong(partial.id, buffer); // message id
      buffer.writeByte(partial.version); // version
      buffer.writeByte(partial.code); // opcode
   }

   // Header, listener id, include-state flag, filter/converter factories and (protocol >= 21) the raw-data flag.
   private void writeAddClientListener(AddClientListenerOp op, ByteBuf buffer) {
      writeHeader(op, buffer);
      writeRangedBytes(op.listenerId, buffer);
      buffer.writeByte(op.includeState ? 1 : 0);
      writeNamedFactory(op.filterFactory, buffer);
      writeNamedFactory(op.converterFactory, buffer);
      if (protocolVersion >= 21)
         buffer.writeByte(op.useRawData ? 1 : 0);
   }

   // Generic request: header, (protocol < 20) an empty transaction id, then the opcode-specific body.
   private void writeOperation(Op op, ByteBuf buffer) {
      writeHeader(op, buffer);
      if (protocolVersion < 20)
         writeRangedBytes(new byte[0], buffer); // transaction id
      if (isKeyBased(op.code)) {
         writeKeyBasedBody(op, buffer);
      } else if (op.code == 0x19) {
         writeUnsignedInt(((BulkGetOp) op).count, buffer); // Entry count
      } else if (op.code == 0x1D) {
         writeUnsignedInt(((BulkGetKeysOp) op).scope, buffer); // Bulk Get Keys Scope
      } else if (op.code == 0x1F) {
         writeRangedBytes(((QueryOp) op).query, buffer);
      } else if (op.code == 0x23) {
         AuthOp authOp = (AuthOp) op;
         writeStringOrEmpty(authOp.mech, buffer);
         writeRangedBytes(authOp.response, buffer);
      }
   }

   // Every opcode except clear, stats, ping, bulk get, bulk get keys, query, auth mech list, auth and 0x29 carries a key.
   private static boolean isKeyBased(int code) {
      return code != 0x13 && code != 0x15
            && code != 0x17 && code != 0x19
            && code != 0x1D && code != 0x1F
            && code != 0x21 && code != 0x23
            && code != 0x29;
   }

   // Key, optional stream offset, then (if there is a value) expiration, data version and the value itself.
   private void writeKeyBasedBody(Op op, ByteBuf buffer) {
      writeRangedBytes(op.key, buffer); // key length + key
      if (op.code == 0x37) {
         // GetStream has an offset
         writeUnsignedInt(((GetStreamOp) op).offset, buffer);
      }
      if (op.value == null)
         return;
      // removeIfUnmodified (0x0D) sends neither expiration nor value, only the data version
      boolean removeIfUnmodified = op.code == 0x0D;
      if (!removeIfUnmodified)
         writeExpiration(op, buffer);
      if (op.code == 0x09 || op.code == 0x0D || op.code == 0x39)
         buffer.writeLong(op.dataVersion);
      if (op.code == 0x39) {
         writeChunked(op.value, buffer);
      } else if (!removeIfUnmodified) {
         writeRangedBytes(op.value, buffer); // value length + value
      }
   }

   // Protocol >= 22 prefixes lifespan/maxIdle with a time-unit byte; older versions write the two values directly.
   private void writeExpiration(Op op, ByteBuf buffer) {
      if (protocolVersion >= 22) {
         if (op.lifespan > 0 || op.maxIdle > 0) {
            buffer.writeByte(0); // seconds for both
            writeUnsignedInt(op.lifespan, buffer); // lifespan
            writeUnsignedInt(op.maxIdle, buffer); // maxIdle
         } else {
            // Unit byte only, no values follow. NOTE(review): 0x88 presumably marks "no expiration" for both
            // lifespan and maxIdle; confirm against the Hot Rod 2.2 protocol spec.
            buffer.writeByte(0x88);
         }
      } else {
         writeUnsignedInt(op.lifespan, buffer); // lifespan
         writeUnsignedInt(op.maxIdle, buffer); // maxIdle
      }
   }

   // Length-prefixed chunks of at most STREAM_CHUNK_SIZE bytes, terminated by a zero-length chunk.
   private static void writeChunked(byte[] value, ByteBuf buffer) {
      for (int offset = 0; offset < value.length; ) {
         int chunk = Math.min(value.length - offset, STREAM_CHUNK_SIZE);
         writeUnsignedInt(chunk, buffer);
         buffer.writeBytes(value, offset, chunk);
         offset += chunk;
      }
      writeUnsignedInt(0, buffer);
   }

   // Writes a UTF-8, length-prefixed string; an empty string is written as a bare zero length.
   private static void writeStringOrEmpty(String s, ByteBuf buffer) {
      if (!s.isEmpty()) {
         writeRangedBytes(s.getBytes(StandardCharsets.UTF_8), buffer);
      } else {
         writeUnsignedInt(0, buffer); // Zero length
      }
   }

   // Factory name, parameter count (one byte) and each parameter; a single 0 byte when there is no factory.
   private void writeNamedFactory(Optional<KeyValuePair<String, List<byte[]>>> namedFactory, ByteBuf buffer) {
      if (namedFactory.isPresent()) {
         KeyValuePair<String, List<byte[]>> factory = namedFactory.get();
         writeString(factory.getKey(), buffer);
         buffer.writeByte(factory.getValue().size());
         factory.getValue().forEach(bytes -> writeRangedBytes(bytes, buffer));
      } else {
         buffer.writeByte(0);
      }
   }

   // Common request header shared by every operation except PartialOp.
   private void writeHeader(Op op, ByteBuf buffer) {
      buffer.writeByte(op.magic); // magic
      writeUnsignedLong(op.id, buffer); // message id
      buffer.writeByte(op.version); // version
      buffer.writeByte(op.code); // opcode
      writeStringOrEmpty(op.cacheName, buffer); // cache name length + cache name
      writeUnsignedInt(op.flags, buffer); // flags
      buffer.writeByte(op.clientIntel); // client intelligence
      writeUnsignedInt(op.topologyId, buffer); // topology id
   }
}
class Decoder extends ReplayingDecoder<Void> {
private final HotRodClient client;
private final static Log log = LogFactory.getLog(Decoder.class, Log.class);
Decoder(HotRodClient client) {
this.client = client;
}
@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf buf, List<Object> out) throws Exception {
log.trace("Decode response from server");
buf.readUnsignedByte(); // magic byte
long id = ExtendedByteBuf.readUnsignedLong(buf);
HotRodOperation opCode = HotRodOperation.fromResponseOpCode((byte) buf.readUnsignedByte());
OperationStatus status = OperationStatus.fromCode((byte) buf.readUnsignedByte());
short topologyChangeMarker = buf.readUnsignedByte();
Op op = client.idToOp.get(id);
AbstractTestTopologyAwareResponse topologyChangeResponse;
if (topologyChangeMarker == 1) {
int topologyId = readUnsignedInt(buf);
if (op.clientIntel == Constants.INTELLIGENCE_TOPOLOGY_AWARE) {
int numberClusterMembers = readUnsignedInt(buf);
ServerAddress[] viewArray = new ServerAddress[numberClusterMembers];
for (int i = 0; i < numberClusterMembers; i++) {
String host = readString(buf);
int port = readUnsignedShort(buf);
viewArray[i] = new ServerAddress(host, port);
}
topologyChangeResponse = new TestTopologyAwareResponse(topologyId, Arrays.asList(viewArray));
} else if (op.clientIntel == Constants.INTELLIGENCE_HASH_DISTRIBUTION_AWARE) {
if (op.version < 20)
topologyChangeResponse = read1xHashDistAwareHeader(buf, topologyId, op);
else
topologyChangeResponse = read2xHashDistAwareHeader(buf, topologyId, op);
} else {
throw new UnsupportedOperationException(
"Client intelligence " + op.clientIntel + " not supported");
}
} else {
topologyChangeResponse = null;
}
Response resp;
switch (opCode) {
case STATS:
int size = readUnsignedInt(buf);
Map<String, String> stats = new HashMap<>();
for (int i = 0; i < size; ++i) {
stats.put(readString(buf), readString(buf));
}
resp = new TestStatsResponse(op.version, id, op.cacheName, op.clientIntel,
op.topologyId, topologyChangeResponse, stats);
break;
case PUT:
case PUT_IF_ABSENT:
case REPLACE:
case REPLACE_IF_UNMODIFIED:
case REMOVE:
case REMOVE_IF_UNMODIFIED:
case PUT_STREAM:
boolean checkPrevious;
if (op.version >= 10 && op.version <= 13) {
checkPrevious = (op.flags & ProtocolFlag.ForceReturnPreviousValue.getValue()) == 1;
} else {
checkPrevious = status == SuccessWithPrevious || status == NotExecutedWithPrevious;
}
if (checkPrevious) {
int length = readUnsignedInt(buf);
if (length == 0) {
resp = new TestResponseWithPrevious(op.version, id, op.cacheName,
op.clientIntel, opCode, status, op.topologyId, topologyChangeResponse, Optional.empty());
} else {
byte[] previous = new byte[length];
buf.readBytes(previous);
resp = new TestResponseWithPrevious(op.version, id, op.cacheName,
op.clientIntel, opCode, status, op.topologyId, topologyChangeResponse, Optional.of(previous));
}
} else {
resp = new TestResponse(op.version, id, op.cacheName, op.clientIntel,
opCode, status, op.topologyId, topologyChangeResponse);
}
break;
case CONTAINS_KEY:
case CLEAR:
case PING:
case ADD_CLIENT_LISTENER:
case REMOVE_CLIENT_LISTENER:
resp = new TestResponse(op.version, id, op.cacheName, op.clientIntel, opCode,
status, op.topologyId, topologyChangeResponse);
break;
case GET_WITH_VERSION:
if (status == Success) {
long version = buf.readLong();
Optional<byte[]> data = Optional.of(ExtendedByteBuf.readRangedBytes(buf));
resp = new TestGetWithVersionResponse(op.version, id, op.cacheName,
op.clientIntel, opCode, status, op.topologyId, topologyChangeResponse, data, version);
} else {
resp = new TestGetWithVersionResponse(op.version, id, op.cacheName,
op.clientIntel, opCode, status, op.topologyId, topologyChangeResponse, Optional.empty(), 0);
}
break;
case GET_WITH_METADATA:
case GET_STREAM:
if (status == Success) {
long created = -1;
int lifespan = -1;
long lastUsed = -1;
int maxIdle = -1;
byte flags = buf.readByte();
if ((flags & 0x01) != 0x01) {
created = buf.readLong();
lifespan = readUnsignedInt(buf);
}
if ((flags & 0x02) != 0x02) {
lastUsed = buf.readLong();
maxIdle = readUnsignedInt(buf);
}
long version = buf.readLong();
Optional<byte[]> data = Optional.of(ExtendedByteBuf.readRangedBytes(buf));
resp = new TestGetWithMetadataResponse(op.version, id, op.cacheName,
op.clientIntel, opCode, status, op.topologyId, topologyChangeResponse, data, version,
created, lifespan, lastUsed, maxIdle);
} else {
resp = new TestGetWithMetadataResponse(op.version, id, op.cacheName,
op.clientIntel, opCode, status, op.topologyId, topologyChangeResponse, Optional.empty(), 0,
-1, -1, -1, -1);
}
break;
case GET:
if (status == Success) {
Optional<byte[]> data = Optional.of(ExtendedByteBuf.readRangedBytes(buf));
resp = new TestGetResponse(op.version, id, op.cacheName, op.clientIntel,
opCode, status, op.topologyId, topologyChangeResponse, data);
} else {
resp = new TestGetResponse(op.version, id, op.cacheName, op.clientIntel,
opCode, status, op.topologyId, topologyChangeResponse, Optional.empty());
}
break;
case BULK_GET:
byte done = buf.readByte();
Map<byte[], byte[]> bulkBuffer = new HashMap<>();
while (done == 1) {
bulkBuffer.put(ExtendedByteBuf.readRangedBytes(buf), ExtendedByteBuf.readRangedBytes(buf));
done = buf.readByte();
}
resp = new TestBulkGetResponse(op.version, id, op.cacheName, op.clientIntel,
op.topologyId, topologyChangeResponse, bulkBuffer);
break;
case BULK_GET_KEYS:
done = buf.readByte();
Set<byte[]> bulkKeys = new HashSet<>();
while (done == 1) {
bulkKeys.add(ExtendedByteBuf.readRangedBytes(buf));
done = buf.readByte();
}
resp = new TestBulkGetKeysResponse(op.version, id, op.cacheName, op.clientIntel,
op.topologyId, topologyChangeResponse, bulkKeys);
break;
case QUERY:
byte[] result = ExtendedByteBuf.readRangedBytes(buf);
resp = new TestQueryResponse(op.version, id, op.cacheName, op.clientIntel,
op.topologyId, topologyChangeResponse, result);
break;
case AUTH_MECH_LIST:
size = readUnsignedInt(buf);
Set<String> mechs = new HashSet<>();
for (int i = 0; i < size; ++i) {
mechs.add(readString(buf));
}
resp = new TestAuthMechListResponse(op.version, id, op.cacheName, op.clientIntel,
op.topologyId, topologyChangeResponse, mechs);
break;
case AUTH: {
boolean complete = buf.readBoolean();
byte[] challenge = ExtendedByteBuf.readRangedBytes(buf);
resp = new TestAuthResponse(op.version, id, op.cacheName, op.clientIntel,
op.topologyId, topologyChangeResponse, complete, challenge);
break;
}
case CACHE_ENTRY_CREATED_EVENT:
case CACHE_ENTRY_MODIFIED_EVENT:
case CACHE_ENTRY_REMOVED_EVENT:
byte[] listenerId = ExtendedByteBuf.readRangedBytes(buf);
byte isCustom = buf.readByte();
boolean isRetried = buf.readByte() == 1;
if (isCustom == 1 || isCustom == 2) {
byte[] eventData = ExtendedByteBuf.readRangedBytes(buf);
resp = new TestCustomEvent(client.protocolVersion, id, client.defaultCacheName, opCode, listenerId,
isRetried, eventData);
} else {
byte[] key = ExtendedByteBuf.readRangedBytes(buf);
if (opCode == HotRodOperation.CACHE_ENTRY_REMOVED_EVENT) {
resp = new TestKeyEvent(client.protocolVersion, id, client.defaultCacheName, listenerId, isRetried, key);
} else {
long dataVersion = buf.readLong();
resp = new TestKeyWithVersionEvent(client.protocolVersion, id, client.defaultCacheName,
opCode, listenerId, isRetried, key, dataVersion);
}
}
break;
case SIZE:
long lsize = ExtendedByteBuf.readUnsignedLong(buf);
resp = new TestSizeResponse(op.version, id, op.cacheName, op.clientIntel,
op.topologyId, topologyChangeResponse, lsize);
break;
case ERROR:
if (op == null)
resp = new TestErrorResponse((byte) 10, id, "", (short) 0, status, 0,
topologyChangeResponse, readString(buf));
else
resp = new TestErrorResponse(op.version, id, op.cacheName, op.clientIntel,
status, op.topologyId, topologyChangeResponse, readString(buf));
break;
default:
resp = null;
break;
}
if (resp != null) {
log.tracef("Got response from server: %s", resp);
out.add(resp);
}
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
log.exceptionReported(cause);
}
private AbstractTestTopologyAwareResponse read2xHashDistAwareHeader(ByteBuf buf, int topologyId, Op op) {
int numServersInTopo = readUnsignedInt(buf);
List<ServerAddress> members = new ArrayList<>();
for (int i = 0; i < numServersInTopo; ++i) {
ServerAddress node = new ServerAddress(readString(buf), readUnsignedShort(buf));
members.add(node);
}
byte hashFunction = buf.readByte();
int numSegments = readUnsignedInt(buf);
List<Iterable<ServerAddress>> segments = new ArrayList<>();
if (hashFunction > 0) {
for (int i = 1; i <= numSegments; ++i) {
byte owners = buf.readByte();
List<ServerAddress> membersInSegment = new ArrayList<>();
for (int j = 1; j <= owners; ++j) {
int index = readUnsignedInt(buf);
membersInSegment.add(members.get(index));
}
segments.add(membersInSegment);
}
}
return new TestHashDistAware20Response(topologyId, members, segments, hashFunction);
}
private AbstractTestTopologyAwareResponse read1xHashDistAwareHeader(ByteBuf buf, int topologyId, Op op) {
int numOwners = readUnsignedShort(buf);
byte hashFunction = buf.readByte();
int hashSpace = readUnsignedInt(buf);
int numServersInTopo = readUnsignedInt(buf);
if (op.version == 10) {
return read10HashDistAwareHeader(buf, topologyId,
numOwners, hashFunction, hashSpace, numServersInTopo);
} else {
return read11HashDistAwareHeader(buf, topologyId,
numOwners, hashFunction, hashSpace, numServersInTopo);
}
}
/**
 * Reads the server list of a protocol 1.0 hash-distribution-aware header.
 * <p>
 * Each of the {@code numServersInTopo} entries is a host/port pair followed by a 4-byte hash id.
 * A server owning several hash ids (more than one virtual node) is listed once per hash id, so
 * the number of distinct servers is only known once the whole list has been read.
 *
 * @param buf              buffer positioned at the first server entry
 * @param topologyId       topology id already read from the header
 * @param numOwners        owner count read by the caller
 * @param hashFunction     hash function id read by the caller
 * @param hashSpace        hash space read by the caller
 * @param numServersInTopo number of (address, hash id) entries to read
 * @return the distinct members and the hash ids collected for each of them
 */
private AbstractTestTopologyAwareResponse read10HashDistAwareHeader(ByteBuf buf, int topologyId,
      int numOwners, byte hashFunction, int hashSpace, int numServersInTopo) {
   Set<ServerAddress> members = new HashSet<>();
   Map<ServerAddress, List<Integer>> allHashIds = new HashMap<>();
   for (int i = 0; i < numServersInTopo; ++i) {
      // Host is read before port (Java evaluates arguments left to right).
      ServerAddress node = new ServerAddress(readString(buf), readUnsignedShort(buf));
      int hashId = buf.readInt();
      members.add(node);
      // Grouping by address keeps every hash id in list order, even if a server's entries are
      // not adjacent. Replacing the list per run of entries would lose the earlier ids.
      allHashIds.computeIfAbsent(node, k -> new ArrayList<>()).add(hashId);
   }
   return new TestHashDistAware10Response(topologyId, members,
         allHashIds, numOwners, hashFunction, hashSpace);
}
/**
 * Reads the remainder of a protocol 1.1 hash-distribution-aware header.
 * <p>
 * The virtual node count comes first, followed by {@code numServersInTopo} entries, each a
 * host/port pair and a 4-byte hash id.
 *
 * @param buf              buffer positioned after the fields shared with 1.0
 * @param topologyId       topology id already read from the header
 * @param numOwners        owner count read by the caller
 * @param hashFunction     hash function id read by the caller (primitive, matching the 1.0 reader)
 * @param hashSpace        hash space read by the caller
 * @param numServersInTopo number of server entries to read
 * @return each server's hash id together with the shared header fields
 */
private AbstractTestTopologyAwareResponse read11HashDistAwareHeader(ByteBuf buf, int topologyId,
      int numOwners, byte hashFunction, int hashSpace,
      int numServersInTopo) {
   int numVirtualNodes = readUnsignedInt(buf);
   // Keyed by server address; the value is that server's hash id.
   Map<ServerAddress, Integer> addressToHash = new HashMap<>();
   for (int i = 0; i < numServersInTopo; ++i) {
      // Address (host, then port) precedes the hash id on the wire.
      ServerAddress address = new ServerAddress(readString(buf), readUnsignedShort(buf));
      addressToHash.put(address, buf.readInt());
   }
   return new TestHashDistAware11Response(topologyId, addressToHash,
         numOwners, hashFunction, hashSpace, numVirtualNodes);
}
}
/**
 * Inbound handler of the test Hot Rod client.
 * <p>
 * It does two things:
 * <ul>
 *   <li>stores every decoded {@link TestResponse} by message id so {@link #getResponse(long)}
 *       can poll for it;</li>
 *   <li>routes decoded listener events to the {@link TestClientListener} registered under the
 *       event's listener id.</li>
 * </ul>
 */
class ClientHandler extends ChannelInboundHandlerAdapter {
   private static final Log log = LogFactory.getLog(ClientHandler.class, Log.class);

   /** Approximate time, in seconds, that {@link #getResponse(long)} polls before giving up. */
   final int rspTimeoutSeconds;

   // Concurrent maps: channelRead fills them while getResponse / add / remove are called by the
   // test code, presumably from other threads than the Netty event loop.
   private final Map<Long, TestResponse> responses = new ConcurrentHashMap<>();
   private final Map<WrappedByteArray, TestClientListener> clientListeners = new ConcurrentHashMap<>();

   ClientHandler(int rspTimeoutSeconds) {
      this.rspTimeoutSeconds = rspTimeoutSeconds;
   }

   /** Registers {@code listener} under its id, replacing any listener with the same id. */
   void addClientListener(TestClientListener listener) {
      clientListeners.put(new WrappedByteArray(listener.getId()), listener);
   }

   /** Unregisters the listener with the given id; later events for that id are dropped. */
   void removeClientListener(byte[] listenerId) {
      clientListeners.remove(new WrappedByteArray(listenerId));
   }

   /**
    * Dispatches a decoded message.
    * <p>
    * Event types are checked before the generic {@link TestResponse} case. Events whose listener
    * id has no registered listener are logged at trace level and dropped rather than failing
    * with a {@code NullPointerException}.
    *
    * @throws IllegalArgumentException if {@code msg} is none of the expected types
    */
   @Override
   public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
      if (msg instanceof TestKeyWithVersionEvent) {
         TestKeyWithVersionEvent e = (TestKeyWithVersionEvent) msg;
         TestClientListener listener;
         switch (e.getOperation()) {
            case CACHE_ENTRY_CREATED_EVENT:
               listener = findListener(e.listenerId);
               if (listener != null) {
                  listener.onCreated(e);
               }
               break;
            case CACHE_ENTRY_MODIFIED_EVENT:
               listener = findListener(e.listenerId);
               if (listener != null) {
                  listener.onModified(e);
               }
               break;
            default:
               // Other operations carried by a versioned key event are not dispatched.
               break;
         }
      } else if (msg instanceof TestKeyEvent) {
         TestKeyEvent e = (TestKeyEvent) msg;
         TestClientListener listener = findListener(e.listenerId);
         if (listener != null) {
            listener.onRemoved(e);
         }
      } else if (msg instanceof TestCustomEvent) {
         TestCustomEvent e = (TestCustomEvent) msg;
         TestClientListener listener = findListener(e.listenerId);
         if (listener != null) {
            listener.onCustom(e);
         }
      } else if (msg instanceof TestResponse) {
         TestResponse resp = (TestResponse) msg;
         log.tracef("Put %s in responses", resp);
         responses.put(resp.getMessageId(), resp);
      } else {
         throw new IllegalArgumentException("Unsupport object: " + msg);
      }
   }

   /**
    * Looks up the listener registered under {@code listenerId}.
    * <p>
    * When no listener is registered it logs at trace level and returns {@code null}. This can
    * happen, for example, when an event is still in flight after removeClientListener.
    */
   private TestClientListener findListener(byte[] listenerId) {
      TestClientListener listener = clientListeners.get(new WrappedByteArray(listenerId));
      if (listener == null) {
         log.tracef("Dropping event for unregistered client listener %s", Arrays.toString(listenerId));
      }
      return listener;
   }

   /**
    * Waits for the response to {@code messageId} by polling every 100 ms, for roughly
    * {@link #rspTimeoutSeconds} seconds.
    * <p>
    * The response stays in the map after retrieval.
    *
    * @return the response, or {@code null} if none arrived before the timeout
    */
   TestResponse getResponse(long messageId) {
      // TODO: very primitive way of waiting for a response; convert to a Future.
      // long arithmetic so a large timeout cannot overflow the attempt limit.
      long maxAttempts = rspTimeoutSeconds * 10L;
      TestResponse response = responses.get(messageId);
      for (long attempt = 0; response == null && attempt < maxAttempts; attempt++) {
         TestingUtil.sleepThread(100);
         // Re-check after every sleep, including the last one, so no wait is wasted.
         response = responses.get(messageId);
      }
      return response;
   }
}
/**
 * A concrete {@link Op} built entirely from caller-supplied header, key, value and expiration
 * fields.
 * <p>
 * It adds no state of its own.
 */
class PartialOp extends Op {
   public PartialOp(int magic, byte version, byte code, String cacheName, byte[] key,
                    int lifespan, int maxIdle, byte[] value, int flags, long dataVersion,
                    byte clientIntel, int topologyId) {
      super(magic, version, code, cacheName,
            key, lifespan, maxIdle, value,
            flags, dataVersion, clientIntel, topologyId);
   }
}
/**
 * Base for operations that carry only header fields.
 * <p>
 * The key and value are handed to {@link Op} as {@code null}; lifespan, maxIdle, flags and
 * dataVersion are all {@code 0}.
 */
abstract class AbstractOp extends Op {
   public AbstractOp(int magic, byte version, byte code, String cacheName, byte clientIntel,
                     int topologyId) {
      super(magic, version, code, cacheName,
            null, 0, 0, null,
            0, 0, clientIntel, topologyId);
   }
}
/** Header-only operation that additionally carries a statistic name. */
class StatsOp extends AbstractOp {
   /** Name of the statistic this request refers to. */
   final String statName;

   public StatsOp(int magic, byte version, byte code, String cacheName, byte clientIntel,
                  int topologyId, String statName) {
      super(magic, version, code, cacheName, clientIntel, topologyId);
      this.statName = statName;
   }
}
/** Header-only operation that additionally carries an integer {@code count}. */
class BulkGetOp extends AbstractOp {
   /** Entry count sent with the request (presumably a limit on returned entries; confirm). */
   final int count;

   public BulkGetOp(int magic, byte version, byte code, String cacheName, byte clientIntel,
                    int topologyId, int count) {
      super(magic, version, code, cacheName, clientIntel, topologyId);
      this.count = count;
   }
}
/** Header-only operation that additionally carries an integer {@code scope}. */
class BulkGetKeysOp extends AbstractOp {
   /** Scope value sent with the key request. */
   final int scope;

   public BulkGetKeysOp(int magic, byte version, byte code, String cacheName, byte clientIntel,
                        int topologyId, int scope) {
      super(magic, version, code, cacheName, clientIntel, topologyId);
      this.scope = scope;
   }
}
/** Query request: a header-only operation with fixed op code 0x1F plus a raw query payload. */
class QueryOp extends AbstractOp {
   private static final byte REQUEST_CODE = 0x1F;

   /** Serialized query bytes. */
   final byte[] query;

   public QueryOp(int magic, byte version, String cacheName, byte clientIntel, int topologyId,
                  byte[] query) {
      super(magic, version, REQUEST_CODE, cacheName, clientIntel, topologyId);
      this.query = query;
   }
}
/**
 * Add-client-listener request (op code 0x25).
 * <p>
 * It carries the listener id, whether initial state is requested, optional filter and converter
 * factories (a name plus parameter byte arrays), and the raw-data flag.
 */
class AddClientListenerOp extends AbstractOp {
   private static final byte REQUEST_CODE = 0x25;

   final byte[] listenerId;
   final boolean includeState;
   final Optional<KeyValuePair<String, List<byte[]>>> filterFactory;
   final Optional<KeyValuePair<String, List<byte[]>>> converterFactory;
   final boolean useRawData;

   public AddClientListenerOp(int magic, byte version, String cacheName, byte clientIntel,
                              int topologyId, byte[] listenerId, boolean includeState,
                              Optional<KeyValuePair<String, List<byte[]>>> filterFactory,
                              Optional<KeyValuePair<String, List<byte[]>>> converterFactory,
                              boolean useRawData) {
      super(magic, version, REQUEST_CODE, cacheName, clientIntel, topologyId);
      this.useRawData = useRawData;
      this.converterFactory = converterFactory;
      this.filterFactory = filterFactory;
      this.includeState = includeState;
      this.listenerId = listenerId;
   }
}
/** Remove-client-listener request (op code 0x27) identifying the listener by its id bytes. */
class RemoveClientListenerOp extends AbstractOp {
   private static final byte REQUEST_CODE = 0x27;

   final byte[] listenerId;

   public RemoveClientListenerOp(int magic, byte version, String cacheName, byte clientIntel,
                                 int topologyId, byte[] listenerId) {
      super(magic, version, REQUEST_CODE, cacheName, clientIntel, topologyId);
      this.listenerId = listenerId;
   }
}
/** Header-only operation with a caller-chosen op code and no extra payload. */
class AuthMechListOp extends AbstractOp {
   public AuthMechListOp(int magic, byte version, byte code, String cacheName, byte clientIntel,
                         int topologyId) {
      super(magic, version, code, cacheName, clientIntel, topologyId);
   }
}
/** Header-only operation carrying a mechanism name and a response byte payload. */
class AuthOp extends AbstractOp {
   /** Mechanism name sent with the request. */
   final String mech;
   /** Response bytes sent with the request. */
   final byte[] response;

   public AuthOp(int magic, byte version, byte code, String cacheName, byte clientIntel,
                 int topologyId, String mech, byte[] response) {
      super(magic, version, code, cacheName, clientIntel, topologyId);
      this.response = response;
      this.mech = mech;
   }
}
/** Size request: a header-only operation with fixed op code 0x29. */
class SizeOp extends AbstractOp {
   private static final byte REQUEST_CODE = 0x29;

   public SizeOp(int magic, byte version, String cacheName, byte clientIntel, int topologyId) {
      super(magic, version, REQUEST_CODE, cacheName, clientIntel, topologyId);
   }
}
/**
 * Get-stream request (op code 0x37) for a key, with a read offset.
 * <p>
 * Lifespan and maxIdle are passed to {@link Op} as {@code -1}, the value as {@code null} and
 * dataVersion as {@code 0}.
 */
class GetStreamOp extends Op {
   private static final byte REQUEST_CODE = 0x37;

   /** Offset sent with the request. */
   final int offset;

   public GetStreamOp(int magic, byte version, String cacheName, byte[] key, int flags,
                      byte clientIntel, int topologyId, int offset) {
      super(magic, version, REQUEST_CODE, cacheName,
            key, -1, -1, null,
            flags, 0, clientIntel, topologyId);
      this.offset = offset;
   }
}
/**
 * Put-stream request (op code 0x39) for a key/value pair with expiration and data version.
 * <p>
 * Flags are always passed to {@link Op} as {@code 0}.
 */
class PutStreamOp extends Op {
   private static final byte REQUEST_CODE = 0x39;

   public PutStreamOp(int magic, byte version, String cacheName, byte[] key, byte[] value,
                      int lifespan, int maxIdle, long dataVersion, byte clientIntel,
                      int topologyId) {
      super(magic, version, REQUEST_CODE, cacheName,
            key, lifespan, maxIdle, value,
            0, dataVersion, clientIntel, topologyId);
   }
}
/** A server endpoint given as a host name and a port; both fields are final. */
class ServerNode {
   /** Host name or address of the server. */
   final String host;
   /** Port the server listens on. */
   final int port;

   ServerNode(String host, int port) {
      this.port = port;
      this.host = host;
   }
}