/*
* This file is provided to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
package com.jbrisbin.riak.async;
import static com.google.protobuf.ByteString.*;
import static com.jbrisbin.riak.pbc.RiakMessageCodes.*;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicInteger;
import com.basho.riak.client.IRiakObject;
import com.basho.riak.client.bucket.BucketProperties;
import com.basho.riak.client.builders.BucketPropertiesBuilder;
import com.basho.riak.client.query.MapReduceResult;
import com.basho.riak.client.query.WalkResult;
import com.basho.riak.client.raw.RiakResponse;
import com.basho.riak.client.raw.StoreMeta;
import com.basho.riak.client.raw.query.LinkWalkSpec;
import com.basho.riak.client.raw.query.MapReduceTimeoutException;
import com.basho.riak.client.util.CharsetUtils;
import com.basho.riak.pbc.RPB;
import com.google.protobuf.ByteString;
import com.jbrisbin.riak.async.raw.RawAsyncClient;
import com.jbrisbin.riak.async.raw.ServerInfo;
import com.jbrisbin.riak.async.util.DelegatingErrorHandler;
import com.jbrisbin.riak.pbc.RiakObject;
import com.jbrisbin.riak.pbc.RpbFilter;
import com.jbrisbin.riak.pbc.RpbMessage;
import org.apache.commons.codec.binary.Base64;
import org.codehaus.jackson.map.ObjectMapper;
import org.glassfish.grizzly.Connection;
import org.glassfish.grizzly.EmptyCompletionHandler;
import org.glassfish.grizzly.filterchain.BaseFilter;
import org.glassfish.grizzly.filterchain.FilterChain;
import org.glassfish.grizzly.filterchain.FilterChainBuilder;
import org.glassfish.grizzly.filterchain.FilterChainContext;
import org.glassfish.grizzly.filterchain.NextAction;
import org.glassfish.grizzly.filterchain.TransportFilter;
import org.glassfish.grizzly.memory.HeapMemoryManager;
import org.glassfish.grizzly.nio.transport.TCPNIOTransport;
import org.glassfish.grizzly.nio.transport.TCPNIOTransportBuilder;
import org.glassfish.grizzly.strategies.WorkerThreadIOStrategy;
import org.glassfish.grizzly.threadpool.ThreadPoolConfig;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* @author Jon Brisbin <jon@jbrisbin.com>
*/
/**
 * Asynchronous client for Riak's Protocol Buffers interface, built on a Grizzly NIO transport.
 *
 * <p>All operations share a single {@link Connection}. Each request is wrapped in an
 * {@link RpbRequest} together with the {@link Promise} handed back to the caller. The
 * {@link PendingRequestFilter} queues requests as they are written and completes the oldest
 * queued promise whenever a response is read, so responses are matched to requests in FIFO order.
 *
 * <p>Failures while obtaining the connection or issuing a write are reported to the configured
 * {@link DelegatingErrorHandler} instead of being thrown to the caller.
 *
 * @author Jon Brisbin <jon@jbrisbin.com>
 */
public class RiakAsyncClient implements RawAsyncClient {
  /** Sizes both the default worker pool and Grizzly's worker thread pool. */
  private static final int PROCESSORS = Runtime.getRuntime().availableProcessors();

  private final Logger log = LoggerFactory.getLogger(getClass());
  private String host = "localhost";
  private Integer port = 8087;
  /** Seconds; used for connecting, client-id negotiation and worker thread keep-alive. */
  private Long timeout = 30L;
  private DelegatingErrorHandler errorHandler = new DelegatingErrorHandler();
  private ExecutorService workerPool = Executors.newFixedThreadPool(PROCESSORS);
  private final HeapMemoryManager heap = new HeapMemoryManager();
  private final TCPNIOTransport transport = TCPNIOTransportBuilder.newInstance().build();
  private FilterChain filterChain;
  private Connection connection;
  private int maxConnectionRetries = 5;
  /** Failure counter shared by connection attempts and store-write retries. */
  private final AtomicInteger retries = new AtomicInteger(0);
  /** Requests that have been written but not yet answered, oldest first. */
  private final LinkedBlockingQueue<RpbRequest> pendingRequests = new LinkedBlockingQueue<RpbRequest>();
  private final ObjectMapper mapper = new ObjectMapper();

  /** Creates a client for {@code localhost:8087} and connects immediately. */
  public RiakAsyncClient() {
    start();
  }

  /**
   * Creates a client for the given node and connects immediately.
   *
   * @param host Riak host name
   * @param port Riak Protocol Buffers port
   */
  public RiakAsyncClient(String host, Integer port) {
    this.host = host;
    this.port = port;
    start();
  }

  /**
   * Creates a client for the given node, using {@code timeout} seconds for connect and
   * client-id negotiation, and connects immediately.
   *
   * @param host    Riak host name
   * @param port    Riak Protocol Buffers port
   * @param timeout timeout in seconds
   */
  public RiakAsyncClient(String host, Integer port, Long timeout) {
    this.host = host;
    this.port = port;
    this.timeout = timeout;
    // Start like the other constructors; without this the transport is never configured,
    // started or connected, and every request fails.
    start();
  }

  /** Registers {@code errorHandler} for errors assignable to {@code t}. */
  public RiakAsyncClient registerErrorHandler(Class<? extends Throwable> t, ErrorHandler errorHandler) {
    this.errorHandler.registerErrorHandler(t, errorHandler);
    return this;
  }

  public String getHost() {
    return host;
  }

  public RiakAsyncClient setHost(String host) {
    this.host = host;
    return this;
  }

  public Integer getPort() {
    return port;
  }

  public RiakAsyncClient setPort(Integer port) {
    this.port = port;
    return this;
  }

  public Long getTimeout() {
    return timeout;
  }

  public RiakAsyncClient setTimeout(Long timeout) {
    this.timeout = timeout;
    return this;
  }

  public ExecutorService getWorkerPool() {
    return workerPool;
  }

  public RiakAsyncClient setWorkerPool(ExecutorService workerPool) {
    this.workerPool = workerPool;
    return this;
  }

  public RiakAsyncClient setErrorHandler(DelegatingErrorHandler errorHandler) {
    this.errorHandler = errorHandler;
    return this;
  }

  public DelegatingErrorHandler getErrorHandler() {
    return errorHandler;
  }

  /**
   * Closes the connection asynchronously and stops the transport once the close completes.
   * Errors are reported to the error handler.
   */
  @SuppressWarnings({"unchecked"})
  public void close() {
    if (null == connection) {
      // start() never obtained a connection; only the transport needs stopping.
      stopTransport();
      return;
    }
    try {
      connection.close(new EmptyCompletionHandler<Connection>() {
        @Override public void failed(Throwable throwable) {
          errorHandler.handleError(throwable);
        }

        @Override public void completed(Connection result) {
          stopTransport();
          if (log.isDebugEnabled())
            log.debug(String.format("Disconnected from %s:%s", host, port));
        }
      });
    } catch (IOException e) {
      errorHandler.handleError(e);
    }
  }

  /** Stops the Grizzly transport, reporting any {@link IOException} to the error handler. */
  private void stopTransport() {
    try {
      transport.stop();
    } catch (IOException e) {
      errorHandler.handleError(e);
    }
  }

  /** Builds the filter chain, configures and starts the transport, then opens the connection. */
  private void start() {
    FilterChainBuilder clientChainBuilder = FilterChainBuilder.stateless();
    clientChainBuilder.add(new TransportFilter());
    clientChainBuilder.add(new RpbFilter(heap));
    clientChainBuilder.add(new PendingRequestFilter());
    filterChain = clientChainBuilder.build();

    transport.setKeepAlive(true);
    transport.setTcpNoDelay(true);
    ThreadPoolConfig config = ThreadPoolConfig
        .defaultConfig()
        .setPoolName("riak-async-client")
        .setCorePoolSize(PROCESSORS)
        .setMaxPoolSize(PROCESSORS)
        .setKeepAliveTime(timeout, TimeUnit.SECONDS);
    transport.setWorkerThreadPoolConfig(config);
    transport.setIOStrategy(WorkerThreadIOStrategy.getInstance());
    transport.setProcessor(filterChain);

    try {
      transport.start();
      connection = getConnection();
    } catch (InterruptedException e) {
      // Restore the interrupt status for code further up the stack.
      Thread.currentThread().interrupt();
      errorHandler.handleError(e);
    } catch (IOException e) {
      errorHandler.handleError(e);
    } catch (ExecutionException e) {
      errorHandler.handleError(e);
    } catch (TimeoutException e) {
      errorHandler.handleError(e);
    }
  }

  /**
   * Returns the open connection, (re)connecting and negotiating a client id if necessary.
   *
   * @throws IOException if no open connection exists and the retry budget is exhausted
   */
  private Connection getConnection() throws IOException, ExecutionException, TimeoutException, InterruptedException {
    if (null != connection && connection.isOpen()) {
      return connection;
    }
    if (retries.get() >= maxConnectionRetries) {
      // Previously the null/closed connection was returned here, causing an NPE in callers.
      throw new IOException(String.format("Unable to connect to %s:%s after %s failed attempts",
          host, port, retries.get()));
    }
    connection = transport.connect(new InetSocketAddress(host, port), new EmptyCompletionHandler<Connection>() {
      @Override public void failed(Throwable throwable) {
        errorHandler.handleError(throwable);
        retries.incrementAndGet();
      }

      @Override public void completed(Connection result) {
        if (retries.get() > 0)
          retries.decrementAndGet();
      }
    }).get(timeout, TimeUnit.SECONDS);
    if (log.isDebugEnabled())
      log.debug(String.format("Connected to %s:%s", host, port));
    // The connection field is already set, so this nested getConnection() call returns it.
    generateAndSetClientId().get(timeout, TimeUnit.SECONDS);
    return connection;
  }

  /**
   * Writes {@code msg} paired with {@code promise} on the current connection. Failures to
   * connect or write are reported to the error handler; the promise is then left pending.
   */
  @SuppressWarnings({"unchecked"})
  private <T> void send(RpbMessage msg, Promise<T> promise) {
    try {
      getConnection().write(new RpbRequest<T>(msg, promise));
    } catch (Exception e) {
      errorHandler.handleError(e);
    }
  }

  @Override public Promise<RiakResponse> fetch(String bucket, String key) throws IOException {
    return fetch(bucket, key, -1);
  }

  /** Fetches {@code bucket/key}; {@code readQuorum} is sent only when positive. */
  @SuppressWarnings({"unchecked"})
  @Override public Promise<RiakResponse> fetch(String bucket, String key, int readQuorum) throws IOException {
    RPB.RpbGetReq.Builder b = RPB.RpbGetReq.newBuilder()
        .setBucket(copyFromUtf8(bucket))
        .setKey(copyFromUtf8(key));
    if (readQuorum > 0) {
      b.setR(readQuorum);
    }
    RpbMessage<RPB.RpbGetReq> msg = new RpbMessage<RPB.RpbGetReq>(MSG_GetReq, b.build());
    final Promise<RiakResponse> promise = new Promise<RiakResponse>();
    send(msg, promise);
    return promise;
  }

  /**
   * Stores {@code object}, applying the optional return-body/dw/w settings from
   * {@code storeMeta}. Failed writes are retried with the same promise until the shared retry
   * budget runs out, at which point the promise is failed.
   */
  @SuppressWarnings({"unchecked"})
  @Override
  public Promise<RiakResponse> store(final IRiakObject object, final StoreMeta storeMeta) {
    RPB.RpbPutReq.Builder b = RPB.RpbPutReq.newBuilder()
        .setBucket(copyFromUtf8(object.getBucket()))
        .setKey(copyFromUtf8(object.getKey()));
    if (null != object.getVClock())
      b.setVclock(copyFrom(object.getVClock().getBytes()));
    if (object instanceof RiakObject) {
      b.setContent(((RiakObject) object).build());
    }
    if (null != storeMeta) {
      if (null != storeMeta.getReturnBody())
        b.setReturnBody(storeMeta.getReturnBody());
      if (null != storeMeta.getDw())
        b.setDw(storeMeta.getDw());
      if (null != storeMeta.getW())
        b.setW(storeMeta.getW());
    }
    RpbMessage<RPB.RpbPutReq> msg = new RpbMessage<RPB.RpbPutReq>(MSG_PutReq, b.build());
    final Promise<RiakResponse> promise = new Promise<RiakResponse>();
    writeStore(msg, promise, object);
    return promise;
  }

  /**
   * Writes one attempt of a store request. On write failure it either retries with the same
   * message and promise or, once the retry budget is spent, reports the error and fails the
   * promise. The old code retried via {@code store()}, which orphaned the caller's promise.
   */
  @SuppressWarnings({"unchecked"})
  private void writeStore(final RpbMessage msg, final Promise<RiakResponse> promise, final IRiakObject object) {
    final RpbRequest<RiakResponse> request = new RpbRequest<RiakResponse>(msg, promise);
    try {
      getConnection().write(request, new EmptyCompletionHandler() {
        @Override public void failed(Throwable throwable) {
          // handleWrite may already have queued this request; drop it so later responses are
          // not matched against a request that never reached the server.
          pendingRequests.remove(request);
          if (retries.incrementAndGet() < maxConnectionRetries) {
            if (log.isDebugEnabled())
              log.debug("Retrying request: " + object);
            writeStore(msg, promise, object);
          } else {
            errorHandler.handleError(throwable);
            promise.setFailure(new RiakException("Store failed after retries: " + throwable));
          }
        }
      });
    } catch (Exception e) {
      errorHandler.handleError(e);
    }
  }

  @Override public Promise<RiakResponse> store(IRiakObject object) throws IOException {
    return store(object, null);
  }

  @Override public Promise<Void> delete(String bucket, String key) throws IOException {
    return delete(bucket, key, -1);
  }

  /**
   * Deletes {@code bucket/key}; {@code deleteQuorum} is sent when non-negative.
   * NOTE(review): unlike fetch(), a quorum of 0 is sent here; confirm this is intended.
   */
  @SuppressWarnings({"unchecked"})
  @Override public Promise<Void> delete(String bucket, String key, int deleteQuorum) {
    RPB.RpbDelReq.Builder b = RPB.RpbDelReq.newBuilder()
        .setBucket(copyFromUtf8(bucket))
        .setKey(copyFromUtf8(key));
    if (deleteQuorum > -1)
      b.setRw(deleteQuorum);
    RpbMessage<RPB.RpbDelReq> msg = new RpbMessage<RPB.RpbDelReq>(MSG_DelReq, b.build());
    final Promise<Void> promise = new Promise<Void>();
    send(msg, promise);
    return promise;
  }

  @SuppressWarnings({"unchecked"})
  @Override public Promise<Set<String>> listBuckets() {
    Promise<Set<String>> promise = new Promise<Set<String>>();
    send(new RpbMessage(MSG_ListBucketsReq, null), promise);
    return promise;
  }

  @SuppressWarnings({"unchecked"})
  @Override public Promise<BucketProperties> fetchBucket(String bucketName) {
    RPB.RpbGetBucketReq.Builder b = RPB.RpbGetBucketReq.newBuilder()
        .setBucket(copyFromUtf8(bucketName));
    Promise<BucketProperties> promise = new Promise<BucketProperties>();
    send(new RpbMessage(MSG_GetBucketReq, b.build()), promise);
    return promise;
  }

  /** Updates bucket {@code name}; only allow-siblings and n-val are transmitted. */
  @SuppressWarnings({"unchecked"})
  @Override public Promise<Void> updateBucket(String name, BucketProperties bucketProperties) throws IOException {
    RPB.RpbBucketProps.Builder pb = RPB.RpbBucketProps.newBuilder();
    if (null != bucketProperties.getAllowSiblings()) {
      pb.setAllowMult(bucketProperties.getAllowSiblings());
    }
    if (null != bucketProperties.getNVal()) {
      pb.setNVal(bucketProperties.getNVal());
    }
    RPB.RpbSetBucketReq.Builder b = RPB.RpbSetBucketReq.newBuilder()
        .setBucket(copyFromUtf8(name))
        .setProps(pb);
    Promise<Void> promise = new Promise<Void>();
    send(new RpbMessage(MSG_SetBucketReq, b.build()), promise);
    return promise;
  }

  @SuppressWarnings({"unchecked"})
  @Override public Promise<Iterable<String>> listKeys(String bucketName) throws IOException {
    RPB.RpbListKeysReq.Builder b = RPB.RpbListKeysReq.newBuilder()
        .setBucket(copyFromUtf8(bucketName));
    Promise<Iterable<String>> promise = new Promise<Iterable<String>>();
    send(new RpbMessage(MSG_ListKeysReq, b.build()), promise);
    return promise;
  }

  @Override public Promise<WalkResult> linkWalk(LinkWalkSpec linkWalkSpec) throws IOException {
    throw new UnsupportedOperationException("Link walking not yet supported");
  }

  /** Submits a JSON map/reduce job; the raw JSON response is exposed on the result. */
  @SuppressWarnings({"unchecked"})
  @Override
  public Promise<MapReduceResult> mapReduce(String json) throws IOException, MapReduceTimeoutException {
    RPB.RpbMapRedReq.Builder b = RPB.RpbMapRedReq.newBuilder()
        .setContentType(copyFromUtf8("application/json"))
        .setRequest(copyFromUtf8(json));
    Promise<MapReduceResult> promise = new Promise<MapReduceResult>();
    send(new RpbMessage(MSG_MapRedReq, b.build()), promise);
    return promise;
  }

  /**
   * Generates a random 6-byte client id, Base64 encodes it (chunked) and sends it to the
   * server. The promise completes with the encoded id's bytes.
   */
  @SuppressWarnings({"unchecked"})
  @Override public Promise<byte[]> generateAndSetClientId() throws IOException {
    SecureRandom sr;
    try {
      sr = SecureRandom.getInstance("SHA1PRNG");
    } catch (NoSuchAlgorithmException e) {
      throw new RuntimeException(e);
    }
    byte[] data = new byte[6];
    sr.nextBytes(data);
    String clientId = CharsetUtils.asString(Base64.encodeBase64Chunked(data), CharsetUtils.ISO_8859_1);
    RPB.RpbSetClientIdReq.Builder b = RPB.RpbSetClientIdReq.newBuilder()
        .setClientId(copyFromUtf8(clientId));
    Promise<byte[]> promise = new Promise<byte[]>();
    send(new RpbMessage(MSG_SetClientIdReq, b.build()), promise);
    return promise;
  }

  /**
   * Sends {@code clientId} to the server.
   * NOTE(review): handleRead completes this promise with the id's byte[] rather than a Void
   * value; callers reading it as Void may hit a ClassCastException.
   */
  @SuppressWarnings({"unchecked"})
  @Override public Promise<Void> setClientId(byte[] clientId) throws IOException {
    RPB.RpbSetClientIdReq.Builder b = RPB.RpbSetClientIdReq.newBuilder()
        .setClientId(copyFrom(clientId));
    Promise<Void> promise = new Promise<Void>();
    send(new RpbMessage(MSG_SetClientIdReq, b.build()), promise);
    return promise;
  }

  @SuppressWarnings({"unchecked"})
  @Override public Promise<byte[]> getClientId() throws IOException {
    Promise<byte[]> promise = new Promise<byte[]>();
    send(new RpbMessage(MSG_GetClientIdReq, null), promise);
    return promise;
  }

  @SuppressWarnings({"unchecked"})
  @Override public Promise<ServerInfo> getServerInfo() throws IOException {
    Promise<ServerInfo> promise = new Promise<ServerInfo>();
    send(new RpbMessage(MSG_GetServerInfoReq, null), promise);
    return promise;
  }

  /** Converts protobuf content entries into {@link RiakObject}s (metadata fields and value). */
  private List<RiakObject> createRiakObjectsFromContent(List<RPB.RpbContent> contents) {
    List<RiakObject> robjs = new ArrayList<RiakObject>(contents.size());
    for (RPB.RpbContent content : contents) {
      RiakObject robj = new RiakObject();
      robj.setVtag(content.getVtag());
      robj.setContentType(content.getContentType());
      robj.setContentEncoding(content.getContentEncoding());
      robj.setLastModified(content.getLastMod());
      robj.setLastModifiedUsec(content.getLastModUsecs());
      robj.setValue(content.getValue());
      robjs.add(robj);
    }
    return robjs;
  }

  /**
   * Last filter in the chain. On write it queues the outgoing {@link RpbRequest} and passes its
   * {@link RpbMessage} downstream. On read it pairs the decoded response with the oldest queued
   * request and completes that request's promise.
   */
  private class PendingRequestFilter extends BaseFilter {
    /** A new connection starts with an empty queue; outstanding promises are abandoned. */
    @Override public NextAction handleConnect(FilterChainContext ctx) throws IOException {
      pendingRequests.clear();
      return ctx.getInvokeAction();
    }

    @SuppressWarnings({"unchecked"})
    @Override public NextAction handleRead(final FilterChainContext ctx) throws IOException {
      final RpbMessage<?> msg = ctx.getMessage();
      final RpbRequest pending = pendingRequests.poll();
      if (null == pending) {
        // No outstanding request to pair this response with.
        return ctx.getStopAction();
      }
      if (log.isDebugEnabled()) {
        // Use %s placeholders: concatenating into the format string breaks on '%'.
        log.debug(String.format("Incoming message: %s", msg));
        log.debug(String.format("Pending request: %s", pending));
      }
      Object response = null;
      switch (msg.getCode()) {
        case MSG_ErrorResp:
          RPB.RpbErrorResp errorResp = (RPB.RpbErrorResp) msg.getMessage();
          pending.getPromise().setFailure(new RiakException(errorResp.getErrmsg().toStringUtf8()));
          break;
        case MSG_DelResp:
          response = Void.INSTANCE;
          break;
        case MSG_GetBucketResp:
          RPB.RpbGetBucketResp bucketResp = (RPB.RpbGetBucketResp) msg.getMessage();
          if (log.isDebugEnabled())
            log.debug(String.format("GetBucket response: %s", bucketResp));
          // NOTE(review): without props the promise is never completed.
          if (bucketResp.hasProps()) {
            RPB.RpbBucketProps bucketProps = bucketResp.getProps();
            response = new BucketPropertiesBuilder()
                .allowSiblings(bucketProps.getAllowMult())
                .nVal(bucketProps.getNVal())
                .build();
          }
          break;
        case MSG_GetClientIdResp:
          RPB.RpbGetClientIdResp clientIdResp = (RPB.RpbGetClientIdResp) msg.getMessage();
          if (log.isDebugEnabled())
            log.debug(String.format("GetClientId response: %s", clientIdResp));
          response = clientIdResp.getClientId().toByteArray();
          break;
        case MSG_GetResp:
          RPB.RpbGetResp getResp = (RPB.RpbGetResp) msg.getMessage();
          if (log.isDebugEnabled())
            log.debug(String.format("Get response: %s", getResp));
          List<RiakObject> fetched = createRiakObjectsFromContent(getResp.getContentList());
          response = new RiakResponse(getResp.getVclock().toByteArray(), fetched.toArray(new RiakObject[fetched.size()]));
          break;
        case MSG_GetServerInfoResp:
          RPB.RpbGetServerInfoResp serverInfoResp = (RPB.RpbGetServerInfoResp) msg.getMessage();
          if (log.isDebugEnabled())
            log.debug(String.format("GetServerInfo response: %s", serverInfoResp));
          response = new ServerInfo(serverInfoResp.getNode().toStringUtf8(),
              serverInfoResp.getServerVersion().toStringUtf8());
          break;
        case MSG_ListBucketsResp:
          RPB.RpbListBucketsResp listBucketsResp = (RPB.RpbListBucketsResp) msg.getMessage();
          if (log.isDebugEnabled())
            log.debug(String.format("ListBuckets response: %s", listBucketsResp));
          Set<String> buckets = new HashSet<String>();
          for (ByteString s : listBucketsResp.getBucketsList()) {
            buckets.add(s.toStringUtf8());
          }
          response = buckets;
          break;
        case MSG_ListKeysResp:
          RPB.RpbListKeysResp listKeysResp = (RPB.RpbListKeysResp) msg.getMessage();
          if (log.isDebugEnabled())
            log.debug(String.format("ListKeys response: %s", listKeysResp));
          Set<String> keys = new HashSet<String>();
          for (ByteString s : listKeysResp.getKeysList()) {
            keys.add(s.toStringUtf8());
          }
          response = keys;
          break;
        case MSG_MapRedResp:
          RPB.RpbMapRedResp mapRedResp = (RPB.RpbMapRedResp) msg.getMessage();
          if (log.isDebugEnabled())
            log.debug(String.format("MapRed response: %s", mapRedResp));
          String json = mapRedResp.getResponse().toStringUtf8();
          RiakAsyncMapReduceResult mapRedResult = new RiakAsyncMapReduceResult();
          mapRedResult.setResultRaw(json);
          response = mapRedResult;
          break;
        case MSG_PingResp:
          // NOTE(review): no result is set, so a ping promise would never complete.
          break;
        case MSG_PutResp:
          RPB.RpbPutResp putResp = (RPB.RpbPutResp) msg.getMessage();
          if (log.isDebugEnabled())
            log.debug(String.format("Put response: %s", putResp));
          List<RiakObject> stored = createRiakObjectsFromContent(putResp.getContentsList());
          response = new RiakResponse(putResp.getVclock().toByteArray(), stored.toArray(new RiakObject[stored.size()]));
          break;
        case MSG_SetBucketResp:
          response = Void.INSTANCE;
          break;
        case MSG_SetClientIdResp:
          // The response carries no payload; answer with the id that was sent in the request.
          RPB.RpbSetClientIdReq clientIdReq = (RPB.RpbSetClientIdReq) pending.getMessage().getMessage();
          response = clientIdReq.getClientId().toByteArray();
          break;
      }
      if (null != response)
        pending.getPromise().setResult(response);
      if (log.isDebugEnabled()) {
        long elapsed = System.currentTimeMillis() - pending.getStart();
        String msgClazz = "null";
        if (null != pending.getMessage().getMessage()) {
          msgClazz = pending.getMessage().getMessage().getClass().getSimpleName();
        }
        log.debug(String.format("Round-trip time for %s: %s", msgClazz, elapsed));
      }
      return ctx.getStopAction();
    }

    /** Queues the request before unwrapping it, so its response can be paired in FIFO order. */
    @Override public NextAction handleWrite(FilterChainContext ctx) throws IOException {
      RpbRequest pending = ctx.getMessage();
      pendingRequests.add(pending);
      ctx.setMessage(pending.getMessage());
      return ctx.getInvokeAction();
    }

    @Override public void exceptionOccurred(FilterChainContext ctx, Throwable error) {
      errorHandler.handleError(error);
    }
  }

  /**
   * An outgoing protobuf message paired with the promise its response completes, plus the
   * creation time used for round-trip logging.
   */
  private static class RpbRequest<T> {
    private long start = System.currentTimeMillis();
    private final RpbMessage message;
    private final Promise<T> promise;

    private RpbRequest(RpbMessage message, Promise<T> promise) {
      this.message = message;
      this.promise = promise;
    }

    public void setStart() {
      start = System.currentTimeMillis();
    }

    public long getStart() {
      return start;
    }

    public RpbMessage getMessage() {
      return message;
    }

    public Promise<T> getPromise() {
      return promise;
    }

    @Override public String toString() {
      return "RpbRequest{" +
          "start=" + start +
          ", message=" + message +
          ", promise=" + promise + "}";
    }
  }
}