package org.infinispan.persistence.rest;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.CompletionService;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
import org.apache.commons.codec.EncoderException;
import org.apache.commons.codec.net.URLCodec;
import org.infinispan.commons.configuration.ConfiguredBy;
import org.infinispan.commons.persistence.Store;
import org.infinispan.commons.util.Util;
import org.infinispan.container.InternalEntryFactory;
import org.infinispan.executors.ExecutorAllCompletionService;
import org.infinispan.filter.KeyFilter;
import org.infinispan.marshall.core.MarshalledEntry;
import org.infinispan.metadata.InternalMetadata;
import org.infinispan.metadata.Metadata;
import org.infinispan.metadata.impl.InternalMetadataImpl;
import org.infinispan.persistence.TaskContextImpl;
import org.infinispan.persistence.keymappers.MarshallingTwoWayKey2StringMapper;
import org.infinispan.persistence.rest.configuration.ConnectionPoolConfiguration;
import org.infinispan.persistence.rest.configuration.RestStoreConfiguration;
import org.infinispan.persistence.rest.logging.Log;
import org.infinispan.persistence.rest.metadata.MetadataHelper;
import org.infinispan.persistence.spi.AdvancedLoadWriteStore;
import org.infinispan.persistence.spi.InitializationContext;
import org.infinispan.persistence.spi.PersistenceException;
import org.infinispan.util.logging.LogFactory;
import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufInputStream;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.DefaultHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpClientCodec;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpVersion;
import net.jcip.annotations.ThreadSafe;
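/**
* A shared cache store that persists entries on a remote server over HTTP: entries are written with PUT,
* read with GET and removed with DELETE, using a Netty based HTTP client.
*
* @author Tristan Tarrant
* @since 6.0
*/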
@Store(shared = true)
@ThreadSafe
@ConfiguredBy(RestStoreConfiguration.class)
public class RestStore implements AdvancedLoadWriteStore {
// Logger for this store.
private static final Log log = LogFactory.getLog(RestStore.class, Log.class);
// Names of the HTTP headers that carry the lifespan and max-idle values, expressed in seconds.
private static final String TIME_TO_LIVE_SECONDS = "timeToLiveSeconds";
private static final String MAX_IDLE_TIME_SECONDS = "maxIdleTimeSeconds";

// Assigned by init().
private InitializationContext ctx;
private volatile RestStoreConfiguration configuration;

// Resolved by start(); iceFactory may alternatively be injected through its setter beforehand.
private InternalEntryFactory iceFactory;
private MarshallingTwoWayKey2StringMapper key2StringMapper;
private MetadataHelper metadataHelper;
private String path;

// Netty client state created by start().
private EventLoopGroup workerGroup;
private Bootstrap bootstrap;
// Upper bound handed to HttpObjectAggregator when a whole response is collected.
private int maxContentLength;

// Encodes the cache name and the mapped keys before they are placed in request URIs.
private final URLCodec urlCodec = new URLCodec();
/**
 * Remembers the initialization context and the store configuration it carries; no connection is made here.
 *
 * @param initializationContext the context supplied for this store
 */
@Override
public void init(InitializationContext initializationContext) {
   RestStoreConfiguration storeConfiguration = initializationContext.getConfiguration();
   this.configuration = storeConfiguration;
   this.ctx = initializationContext;
}
/**
 * Starts the store: resolves the entry factory (unless one was injected), builds the Netty client bootstrap
 * from the connection pool settings, instantiates the key mapper and metadata helper, and computes the base
 * request path.
 *
 * @throws PersistenceException if the cache name cannot be URL-encoded into the base path
 */
@Override
public void start() {
   if (iceFactory == null) {
      iceFactory = ctx.getCache().getAdvancedCache().getComponentRegistry().getComponent(InternalEntryFactory.class);
   }
   ConnectionPoolConfiguration pool = configuration.connectionPool();
   workerGroup = new NioEventLoopGroup();
   Bootstrap b = new Bootstrap().group(workerGroup).channel(NioSocketChannel.class);
   b.handler(new ChannelInitializer<SocketChannel>() {
      @Override
      protected void initChannel(SocketChannel ch) {
         ch.pipeline().addLast(new HttpClientCodec());
      }
   });
   b.option(ChannelOption.SO_KEEPALIVE, true); // TODO make this part of configuration options
   b.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, pool.connectionTimeout());
   b.option(ChannelOption.SO_SNDBUF, pool.bufferSize());// TODO make sure this is appropriate
   b.option(ChannelOption.SO_RCVBUF, pool.bufferSize());
   b.option(ChannelOption.TCP_NODELAY, pool.tcpNoDelay());
   bootstrap = b;
   maxContentLength = 10 * 1024 * 1024; // TODO make this part of configuration options.
   this.key2StringMapper = Util.getInstance(configuration.key2StringMapper(), ctx.getCache().getAdvancedCache().getClassLoader());
   this.key2StringMapper.setMarshaller(ctx.getMarshaller());
   this.path = configuration.path();
   if (configuration.appendCacheNameToPath()) {
      try {
         path = path + urlCodec.encode(ctx.getCache().getName()) + "/";
      } catch (EncoderException e) {
         // Continuing with a path that lacks the cache name would send every request to the wrong URI,
         // so fail the start instead of swallowing the error (same handling as keyToUri).
         workerGroup.shutdownGracefully();
         throw new PersistenceException(e);
      }
   }
   this.metadataHelper = Util.getInstance(configuration.metadataHelper(), ctx.getCache().getAdvancedCache().getClassLoader());
   /*
    * HACK ALERT. Initialize some internal Netty structures early while we are within the scope of the correct classloader to
    * avoid triggering ISPN-7602
    * This needs to be fixed properly within the context of ISPN-7601
    */
   new HttpResponseHandler();
}
/**
 * Stops the store by asking the Netty event loop group to shut down gracefully.
 * Does nothing when the store was never started, so a stop after a failed or skipped start is safe.
 */
@Override
public void stop() {
   if (workerGroup != null) {
      workerGroup.shutdownGracefully();
   }
}
/**
 * Injects the entry factory. It may only be set while no factory is present yet.
 *
 * @param iceFactory the factory to use
 * @throws IllegalStateException if a factory has already been set
 */
public void setInternalCacheEntryFactory(InternalEntryFactory iceFactory) {
   boolean alreadySet = this.iceFactory != null;
   if (!alreadySet) {
      this.iceFactory = iceFactory;
   } else {
      throw new IllegalStateException();
   }
}
/**
 * Builds the request URI for a key: the base path followed by the URL-encoded string form of the key.
 *
 * @param key the cache key
 * @return the URI to use for requests concerning {@code key}
 * @throws PersistenceException if the string form of the key cannot be URL-encoded
 */
private String keyToUri(Object key) {
   try {
      String mappedKey = key2StringMapper.getStringMapping(key);
      String encodedKey = urlCodec.encode(mappedKey);
      return path + encodedKey;
   } catch (EncoderException encodingFailure) {
      throw new PersistenceException(encodingFailure);
   }
}
/**
 * Produces the request body for an entry. With raw values enabled, or for a textual content type, the value
 * is cast to {@code byte[]} and sent as is; otherwise it is serialized with the context's marshaller.
 *
 * @param contentType the content type determined for the entry
 * @param entry       the entry whose value is sent
 * @return the bytes to use as the request body
 */
private byte[] marshall(String contentType, MarshalledEntry entry) throws IOException, InterruptedException {
   // NOTE(review): unmarshall() yields a String for textual content types, while this cast expects byte[] —
   // confirm which value type callers really store for text entries.
   boolean sendAsIs = configuration.rawValues() || isTextContentType(contentType);
   if (sendAsIs) {
      return (byte[]) entry.getValue();
   }
   return ctx.getMarshaller().objectToByteBuffer(entry.getValue());
}
/**
 * Converts a response body into the value handed back to the cache. With raw values enabled the bytes are
 * returned untouched; for a textual content type they are decoded into a String; otherwise they are
 * deserialized with the context's marshaller.
 *
 * @param contentType the response's Content-Type, may be {@code null} when the header is absent
 * @param b           the response body
 * @return the raw bytes, the decoded text or the unmarshalled object
 */
private Object unmarshall(String contentType, byte[] b) throws IOException, ClassNotFoundException {
   if (configuration.rawValues()) {
      return b;
   }
   // A missing Content-Type is treated as non-text rather than failing with a NullPointerException.
   if (contentType != null && isTextContentType(contentType)) {
      // Decode with a fixed charset (the one process() already uses) instead of the platform default,
      // which differs between JVMs. TODO: honour the charset announced by the response headers.
      return new String(b, "UTF-8");
   }
   return ctx.getMarshaller().objectFromByteBuffer(b);
}
/**
 * Tells whether a content type denotes text: anything starting with {@code text/}, or exactly
 * {@code application/xml} or {@code application/json}.
 *
 * @param contentType the content type to test, may be {@code null} (e.g. response without the header)
 * @return {@code true} for a textual content type, {@code false} otherwise, including for {@code null}
 */
private boolean isTextContentType(String contentType) {
   if (contentType == null) {
      return false;
   }
   return contentType.startsWith("text/") || "application/xml".equals(contentType) || "application/json".equals(contentType);
}
/**
 * Stores an entry on the remote server with an HTTP PUT. When the entry's metadata has an expiry time, its
 * lifespan and max-idle are sent in seconds as request headers. Blocks until the connection is closed.
 *
 * @param entry the entry to store
 * @throws PersistenceException if marshalling, connecting or sending fails, or the thread is interrupted
 */
@Override
public void write(MarshalledEntry entry) {
   Channel ch = null;
   try {
      String contentType = metadataHelper.getContentType(entry);
      ByteBuf content = Unpooled.wrappedBuffer(marshall(contentType, entry));
      DefaultFullHttpRequest put = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.PUT, keyToUri(entry.getKey()), content);
      put.headers().add("Content-Type", contentType);
      put.headers().add("Content-Length", content.readableBytes());
      InternalMetadata metadata = entry.getMetadata();
      if (metadata != null && metadata.expiryTime() > -1) {
         put.headers().add(TIME_TO_LIVE_SECONDS, Long.toString(timeoutToSeconds(metadata.lifespan())));
         put.headers().add(MAX_IDLE_TIME_SECONDS, Long.toString(timeoutToSeconds(metadata.maxIdle())));
      }
      // syncUninterruptibly() rethrows a failed connect, so the real cause is reported instead of a later,
      // unrelated failure when writing to a channel that never connected.
      ch = bootstrap.connect(configuration.host(), configuration.port()).syncUninterruptibly().channel();
      ch.pipeline().addLast(new HttpResponseHandler());
      ch.writeAndFlush(put).sync().channel().closeFuture().sync();
   } catch (InterruptedException e) {
      // Keep the interrupt visible to the caller's thread.
      Thread.currentThread().interrupt();
      throw new PersistenceException(e);
   } catch (Exception e) {
      throw new PersistenceException(e);
   } finally {
      // Do not leave the connection open when the request failed or the wait was interrupted.
      if (ch != null) {
         ch.close();
      }
   }
}
/**
 * Channel handler that waits for the HTTP response, optionally keeps it for the caller, and then closes the
 * channel. Callers block on the channel's close future, so every path through this handler must close it.
 */
private static class HttpResponseHandler extends SimpleChannelInboundHandler<HttpResponse> {
   private final boolean retainResponse;
   private FullHttpResponse response;
   private Throwable cause;

   /** Creates a handler that discards the response and only closes the channel. */
   public HttpResponseHandler() {
      this(false);
   }

   /**
    * @param retainResponse whether the response is kept (with its reference count increased) for
    *                       {@link #getResponse()}; this requires full responses to arrive at this handler
    */
   public HttpResponseHandler(boolean retainResponse) {
      this.retainResponse = retainResponse;
   }

   @Override
   protected void channelRead0(ChannelHandlerContext ctx, HttpResponse msg) throws Exception {
      // Keep only the first response, so that a retained buffer is never overwritten and leaked.
      if (retainResponse && response == null) {
         this.response = ((FullHttpResponse) msg).retain();
      }
      ctx.close();
   }

   @Override
   public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
      // An exception thrown from here is only logged by Netty and leaves the channel open, which would block
      // callers waiting on the close future. Record the failure and close the channel instead.
      this.cause = cause;
      ctx.close();
   }

   /**
    * @return the retained response, or {@code null} if none was received or retention was not requested;
    *         a non-null response must be released by the caller
    */
   public FullHttpResponse getResponse() {
      return response;
   }

   /**
    * @return the failure reported on the channel, or {@code null} if there was none
    */
   public Throwable getCause() {
      return cause;
   }
}
/**
 * Removes all entries by issuing an HTTP DELETE on the store's base path. Blocks until the connection is
 * closed.
 *
 * @throws PersistenceException if connecting or sending fails, or the thread is interrupted
 */
@Override
public void clear() {
   DefaultHttpRequest delete = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.DELETE, path);
   Channel ch = null;
   try {
      // syncUninterruptibly() rethrows a failed connect so that the real cause is reported.
      ch = bootstrap.connect(configuration.host(), configuration.port()).syncUninterruptibly().channel();
      ch.pipeline().addLast(new HttpResponseHandler());
      ch.writeAndFlush(delete).sync().channel().closeFuture().sync();
   } catch (InterruptedException e) {
      // Keep the interrupt visible to the caller's thread.
      Thread.currentThread().interrupt();
      throw new PersistenceException(e);
   } catch (Exception e) {
      throw new PersistenceException(e);
   } finally {
      // Do not leave the connection open when the request failed or the wait was interrupted.
      if (ch != null) {
         ch.close();
      }
   }
}
/**
 * Removes a single entry with an HTTP DELETE on the key's URI.
 *
 * @param key the key to remove
 * @return {@code true} if the server answered with a 2xx status, {@code false} otherwise
 * @throws PersistenceException if the request fails, no response is received, or the thread is interrupted
 */
@Override
public boolean delete(Object key) {
   DefaultHttpRequest delete = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.DELETE, keyToUri(key));
   Channel ch = null;
   try {
      HttpResponseHandler handler = new HttpResponseHandler(true);
      // syncUninterruptibly() rethrows a failed connect so that the real cause is reported.
      ch = bootstrap.connect(configuration.host(), configuration.port()).syncUninterruptibly().channel();
      ch.pipeline().addLast(new HttpObjectAggregator(maxContentLength), handler);
      ch.writeAndFlush(delete).sync().channel().closeFuture().sync();
      FullHttpResponse response = handler.getResponse();
      if (response == null) {
         // The channel was closed before a response arrived; report that instead of a NullPointerException.
         throw new IllegalStateException("Connection closed before an HTTP response was received");
      }
      try {
         return isSuccessful(response.status().code());
      } finally {
         response.release();
      }
   } catch (InterruptedException e) {
      // Keep the interrupt visible to the caller's thread.
      Thread.currentThread().interrupt();
      throw new PersistenceException(e);
   } catch (Exception e) {
      throw new PersistenceException(e);
   } finally {
      // Do not leave the connection open when the request failed or the wait was interrupted.
      if (ch != null) {
         ch.close();
      }
   }
}
/**
 * Loads an entry with an HTTP GET on the key's URI. On 200 the body is converted according to the response's
 * content type and the metadata is rebuilt from the time-to-live and max-idle headers (in seconds); when
 * either is set, the current wall clock time is used as creation and last-used time.
 *
 * @param key the key to load
 * @return the loaded entry, or {@code null} if the server answered 404
 * @throws PersistenceException if the request fails, the server answers with any other status, no response is
 *                              received, or the thread is interrupted
 */
@Override
public MarshalledEntry load(Object key) {
   Channel ch = null;
   try {
      DefaultHttpRequest get = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, keyToUri(key));
      HttpResponseHandler handler = new HttpResponseHandler(true);
      // syncUninterruptibly() rethrows a failed connect so that the real cause is reported.
      ch = bootstrap.connect(configuration.host(), configuration.port()).syncUninterruptibly().channel();
      ch.pipeline().addLast(new HttpObjectAggregator(maxContentLength), handler);
      ch.writeAndFlush(get).sync().channel().closeFuture().sync();
      FullHttpResponse response = handler.getResponse();
      if (response == null) {
         // The channel was closed before a response arrived; report that instead of a NullPointerException.
         throw new IllegalStateException("Connection closed before an HTTP response was received");
      }
      try {
         if (HttpResponseStatus.OK.equals(response.status())) {
            String contentType = response.headers().get(HttpHeaderNames.CONTENT_TYPE);
            long ttl = timeHeaderToSeconds(response.headers().get(TIME_TO_LIVE_SECONDS));
            long maxidle = timeHeaderToSeconds(response.headers().get(MAX_IDLE_TIME_SECONDS));
            Metadata metadata = metadataHelper.buildMetadata(contentType, ttl, TimeUnit.SECONDS, maxidle, TimeUnit.SECONDS);
            InternalMetadata internalMetadata;
            if (metadata.maxIdle() > -1 || metadata.lifespan() > -1) {
               long now = ctx.getTimeService().wallClockTime();
               internalMetadata = new InternalMetadataImpl(metadata, now, now);
            } else {
               internalMetadata = new InternalMetadataImpl(metadata, -1, -1);
            }
            // Copy the body out of the buffer before it is released in the finally block.
            ByteBuf content = response.content();
            byte[] bytes = new byte[content.readableBytes()];
            content.readBytes(bytes);
            return ctx.getMarshalledEntryFactory().newMarshalledEntry(key, unmarshall(contentType, bytes), internalMetadata);
         } else if (HttpResponseStatus.NOT_FOUND.equals(response.status())) {
            return null;
         } else {
            throw log.httpError(response.status().toString());
         }
      } finally {
         response.release();
      }
   } catch (InterruptedException e) {
      // Keep the interrupt visible to the caller's thread.
      Thread.currentThread().interrupt();
      throw new PersistenceException(e);
   } catch (IOException e) {
      throw log.httpError(e);
   } catch (Exception e) {
      throw new PersistenceException(e);
   } finally {
      // Do not leave the connection open when the request failed or the wait was interrupted.
      if (ch != null) {
         ch.close();
      }
   }
}
/**
 * Converts a timeout in milliseconds to whole seconds for the request headers.
 *
 * @param timeout the timeout in milliseconds
 * @return -1 for a negative timeout, 1 for a positive timeout below one second (so that it is not truncated
 *         to 0), otherwise the timeout in seconds rounded down
 */
private long timeoutToSeconds(long timeout) {
   if (timeout < 0) {
      return -1;
   }
   long wholeSeconds = TimeUnit.MILLISECONDS.toSeconds(timeout);
   boolean positiveSubSecond = timeout > 0 && wholeSeconds == 0;
   return positiveSubSecond ? 1 : wholeSeconds;
}
/**
 * Parses the value of a time header.
 *
 * @param header the header value, or {@code null} when the header is absent
 * @return the parsed number, or -1 when the header is absent
 * @throws NumberFormatException if the value is not a valid long
 */
private long timeHeaderToSeconds(String header) {
   if (header == null) {
      return -1;
   }
   return Long.parseLong(header);
}
/**
 * Iterates over the remote entries. The key listing is fetched with a GET on the base path (one key per line,
 * UTF-8), the keys accepted by the filter are grouped into batches of up to 1000, and each batch is submitted
 * to the executor; the method returns once all batches have completed.
 *
 * @param keyFilter       filter applied to the keys, or {@code null} to accept every key
 * @param cacheLoaderTask the task invoked for each entry
 * @param executor        the executor running the batches
 * @param loadValue       whether the task needs the values
 * @param loadMetadata    whether the task needs the metadata
 */
@Override
public void process(KeyFilter keyFilter, final CacheLoaderTask cacheLoaderTask, Executor executor, boolean loadValue, boolean loadMetadata) {
   DefaultHttpRequest get = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, path + "?global");
   get.headers().add(HttpHeaderNames.ACCEPT, "text/plain");
   get.headers().add(HttpHeaderNames.ACCEPT_CHARSET, "UTF-8");
   Channel ch = null;
   try {
      HttpResponseHandler handler = new HttpResponseHandler(true);
      // syncUninterruptibly() rethrows a failed connect so that the real cause is reported.
      ch = bootstrap.connect(configuration.host(), configuration.port()).syncUninterruptibly().channel();
      ch.pipeline().addLast(new HttpObjectAggregator(maxContentLength), handler);
      ch.writeAndFlush(get).sync().channel().closeFuture().sync();
      FullHttpResponse response = handler.getResponse();
      if (response == null) {
         // The channel was closed before a response arrived; report that instead of a NullPointerException.
         throw new IllegalStateException("Connection closed before an HTTP response was received");
      }
      int batchSize = 1000;
      ExecutorAllCompletionService eacs = new ExecutorAllCompletionService(executor);
      final TaskContext taskContext = new TaskContextImpl();
      try {
         BufferedReader reader = new BufferedReader(new InputStreamReader(new ByteBufInputStream(response.content()), "UTF-8"));
         try {
            Set<Object> entries = new HashSet<Object>(batchSize);
            for (String stringKey = reader.readLine(); stringKey != null; stringKey = reader.readLine()) {
               Object key = key2StringMapper.getKeyMapping(stringKey);
               if (keyFilter == null || keyFilter.accept(key)) {
                  entries.add(key);
               }
               if (entries.size() == batchSize) {
                  // Hand the full batch over and start collecting into a fresh set.
                  final Set<Object> batch = entries;
                  entries = new HashSet<Object>(batchSize);
                  submitProcessTask(cacheLoaderTask, eacs, taskContext, batch, loadValue, loadMetadata);
               }
            }
            if (!entries.isEmpty()) {
               submitProcessTask(cacheLoaderTask, eacs, taskContext, entries, loadValue, loadMetadata);
            }
         } finally {
            reader.close();
         }
         eacs.waitUntilAllCompleted();
         if (eacs.isExceptionThrown()) {
            throw new PersistenceException("Execution exception!", eacs.getFirstException());
         }
      } finally {
         response.release();
      }
   } catch (InterruptedException e) {
      // Keep the interrupt visible to the caller's thread.
      Thread.currentThread().interrupt();
      throw log.errorLoadingRemoteEntries(e);
   } catch (Exception e) {
      throw log.errorLoadingRemoteEntries(e);
   } finally {
      // Do not leave the connection open when the request failed or the wait was interrupted.
      if (ch != null) {
         ch.close();
      }
   }
}
/**
 * Submits one batch of keys for processing. For each key the entry is loaded when its value or metadata is
 * needed, reduced to the requested parts, and passed to the task; iteration stops once the task context is
 * stopped.
 *
 * @param cacheLoaderTask the task invoked for each entry
 * @param ecs             the completion service the batch is submitted to
 * @param taskContext     the context shared by all batches
 * @param batch           the keys to process
 * @param loadEntry       whether the task needs the values
 * @param loadMetadata    whether the task needs the metadata
 */
private void submitProcessTask(final CacheLoaderTask cacheLoaderTask, CompletionService ecs,
      final TaskContext taskContext, final Set<Object> batch, final boolean loadEntry,
      final boolean loadMetadata) {
   ecs.submit(new Callable<Void>() {
      @Override
      public Void call() throws Exception {
         try {
            for (Object key : batch) {
               if (taskContext.isStopped())
                  break;
               MarshalledEntry entry = null;
               if (loadEntry || loadMetadata) {
                  entry = load(key);
                  if (entry == null) {
                     // load() returns null when the server no longer has the key (404), e.g. it was removed
                     // after the key listing was fetched. Skip it instead of dereferencing null below.
                     continue;
                  }
               }
               if (!loadEntry || !loadMetadata) {
                  // Expose only the parts that were asked for.
                  entry = ctx.getMarshalledEntryFactory().newMarshalledEntry(key,
                        loadEntry ? entry.getValue() : null,
                        loadMetadata ? entry.getMetadata() : null);
               }
               cacheLoaderTask.processEntry(entry, taskContext);
            }
         } catch (Exception e) {
            log.errorExecutingParallelStoreTask(e);
            throw e;
         }
         return null;
      }
   });
}
/**
 * Intentionally does nothing: removing expired entries is left to the remote server, so neither the executor
 * nor the listener is used.
 */
@Override
public void purge(Executor executor, PurgeListener purgeListener) {
// This should be handled by the remote server
}
/**
 * Counts the remote entries by fetching the key listing with a GET on the base path and counting its lines.
 *
 * @return the number of lines in the key listing
 */
@Override
public int size() {
   Channel ch = null;
   DefaultHttpRequest get = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, path + "?global");
   // Same constant as process(); HttpHeaders.Names is the deprecated equivalent.
   get.headers().add(HttpHeaderNames.ACCEPT, "text/plain");
   try {
      HttpResponseHandler handler = new HttpResponseHandler(true);
      // syncUninterruptibly() rethrows a failed connect so that the real cause is reported.
      ch = bootstrap.connect(configuration.host(), configuration.port()).syncUninterruptibly().channel();
      ch.pipeline().addLast(new HttpObjectAggregator(maxContentLength), handler);
      ch.writeAndFlush(get).sync().channel().closeFuture().sync();
      FullHttpResponse response = handler.getResponse();
      if (response == null) {
         // The channel was closed before a response arrived; report that instead of a NullPointerException.
         throw new IllegalStateException("Connection closed before an HTTP response was received");
      }
      try {
         // Create the reader before the try so that the finally block never closes a null reader.
         BufferedReader reader = new BufferedReader(new InputStreamReader(new ByteBufInputStream(response.content()), "UTF-8"));
         try {
            int count = 0;
            while (reader.readLine() != null) {
               count++;
            }
            return count;
         } finally {
            reader.close();
         }
      } finally {
         response.release();
      }
   } catch (InterruptedException e) {
      // Keep the interrupt visible to the caller's thread.
      Thread.currentThread().interrupt();
      throw log.errorLoadingRemoteEntries(e);
   } catch (Exception e) {
      throw log.errorLoadingRemoteEntries(e);
   } finally {
      // Do not leave the connection open when the request failed or the wait was interrupted.
      if (ch != null) {
         ch.close();
      }
   }
}
/**
 * Tells whether the remote server holds an entry for the key; implemented by loading the entry.
 *
 * @param o the key to look up
 * @return {@code true} if loading the key yields an entry
 */
@Override
public boolean contains(Object o) {
   MarshalledEntry loaded = load(o);
   return loaded != null;
}
/**
 * Tells whether an HTTP status code lies in the 2xx range.
 *
 * @param status the HTTP status code
 * @return {@code true} for codes from 200 to 299 inclusive
 */
private boolean isSuccessful(int status) {
   return 200 <= status && status <= 299;
}
}