/*
* JBoss, Home of Professional Open Source.
* See the COPYRIGHT.txt file distributed with this work for information
* regarding copyright ownership. Some portions may be licensed
* to Red Hat, Inc. under one or more contributor license agreements.
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 2.1 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
* 02110-1301 USA.
*/
package org.teiid.replication.jgroups;
import java.io.IOException;
import java.io.Serializable;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executor;
import org.jgroups.Address;
import org.jgroups.Channel;
import org.jgroups.MembershipListener;
import org.jgroups.Message;
import org.jgroups.MessageListener;
import org.jgroups.ReceiverAdapter;
import org.jgroups.View;
import org.jgroups.blocks.MethodCall;
import org.jgroups.blocks.MethodLookup;
import org.jgroups.blocks.RequestOptions;
import org.jgroups.blocks.ResponseMode;
import org.jgroups.blocks.RpcDispatcher;
import org.jgroups.util.Promise;
import org.jgroups.util.Rsp;
import org.jgroups.util.RspList;
import org.teiid.Replicated;
import org.teiid.Replicated.ReplicationMode;
import org.teiid.core.TeiidRuntimeException;
import org.teiid.logging.LogConstants;
import org.teiid.logging.LogManager;
import org.teiid.query.ObjectReplicator;
import org.teiid.query.ReplicatedObject;
import org.teiid.runtime.RuntimePlugin;
@SuppressWarnings("unchecked")
public class JGroupsObjectReplicator implements ObjectReplicator, Serializable {
// Timeout handed to every JGroupsInputStream created for an incoming state
// transfer (NOTE(review): presumably milliseconds - confirm against JGroupsInputStream).
private static final int IO_TIMEOUT = 15000;
// Used by replicate(...) as the "where is the state" detection timeout for the
// initial state pull when a non-zero start timeout is requested; it ends up as
// the RequestOptions timeout of the hasState broadcast in whereIsState.
private static final int STATE_TIMEOUT = 5000;
private final class ReplicatorRpcDispatcher<S> extends RpcDispatcher {
	// the replicated target; internal state-transfer requests operate on it directly
	private final S object;
	// set (under this dispatcher's monitor) once replicate(...) has finished its
	// initial state pull; answers hasState(null) requests from other members
	private boolean initialized;
	private final HashMap<Method, Short> methodMap;
	// method table indexed by MethodCall id; replicate(...) appends the five
	// internal state-transfer methods (hasState, sendState, createState,
	// buildState, finishState) after the @Replicated methods
	private final ArrayList<Method> methodList;
	// in-flight incoming state streams keyed by [stateId, sender address]
	Map<List<?>, JGroupsInputStream> inputStreams = new ConcurrentHashMap<List<?>, JGroupsInputStream>();

	private ReplicatorRpcDispatcher(Channel channel, MessageListener l,
			MembershipListener l2, Object serverObj, S object,
			HashMap<Method, Short> methodMap, ArrayList<Method> methodList) {
		super(channel, l, l2, serverObj);
		this.object = object;
		this.methodMap = methodMap;
		this.methodList = methodList;
	}

	/**
	 * Handles an incoming request. The internal state-transfer calls (the last
	 * five ids of {@code methodList}) are serviced here directly; any other id is
	 * resolved through the method lookup and invoked on the server object.
	 * Internal calls that originate from this member are ignored.
	 * Failures are returned as the response value rather than thrown.
	 */
	@Override
	public Object handle(Message req) {
		Object body=null;
		if(req == null || req.getLength() == 0) {
			if(log.isErrorEnabled()) log.error("message or message buffer is null"); //$NON-NLS-1$
			return null;
		}
		try {
			body=req_marshaller != null?
				req_marshaller.objectFromBuffer(req.getBuffer(), req.getOffset(), req.getLength())
				: req.getObject(getClass().getClassLoader());
		}
		catch(Throwable e) {
			if(log.isErrorEnabled()) log.error("exception marshalling object", e); //$NON-NLS-1$
			return e;
		}
		if(!(body instanceof MethodCall)) {
			if(log.isErrorEnabled()) log.error("message does not contain a MethodCall object"); //$NON-NLS-1$
			// create an exception to represent this and return it
			return new IllegalArgumentException("message does not contain a MethodCall object") ; //$NON-NLS-1$
		}
		final MethodCall method_call=(MethodCall)body;
		// ids of the internal methods, which replicate(...) appends in this order
		final int hasStateId = methodList.size() - 5;
		final int sendStateId = methodList.size() - 4;
		final int createStateId = methodList.size() - 3;
		final int buildStateId = methodList.size() - 2;
		final int finishStateId = methodList.size() - 1;
		try {
			if(log.isTraceEnabled())
				log.trace("[sender=" + req.getSrc() + "], method_call: " + method_call); //$NON-NLS-1$ //$NON-NLS-2$
			final short id = method_call.getId();
			// internal calls we sent ourselves need no local handling
			if (id >= hasStateId && req.getSrc().equals(local_addr)) {
				return null;
			}
			if (id >= createStateId) {
				// streaming calls: create / build / finish an incoming state stream
				Serializable address = req.getSrc();
				Serializable stateId = (Serializable)method_call.getArgs()[0];
				List<?> key = Arrays.asList(stateId, address);
				JGroupsInputStream is = inputStreams.get(key);
				if (id == createStateId) {
					LogManager.logTrace(LogConstants.CTX_RUNTIME, object, "create state", stateId); //$NON-NLS-1$
					if (is != null) {
						// terminate any stale stream for the same state/sender
						is.receive(null);
					}
					is = new JGroupsInputStream(IO_TIMEOUT);
					this.inputStreams.put(key, is);
					executor.execute(new StreamingRunner(object, stateId, is, null));
				} else if (id == buildStateId) {
					LogManager.logTrace(LogConstants.CTX_RUNTIME, object, "building state", stateId); //$NON-NLS-1$
					if (is != null) {
						is.receive((byte[])method_call.getArgs()[1]);
					}
				} else if (id == finishStateId) {
					LogManager.logTrace(LogConstants.CTX_RUNTIME, object, "finished state", stateId); //$NON-NLS-1$
					if (is != null) {
						is.receive(null);
					}
					this.inputStreams.remove(key);
				}
				return null;
			} else if (id == hasStateId) {
				//hasState
				ReplicatedObject ro = (ReplicatedObject)object;
				Serializable stateId = (Serializable)method_call.getArgs()[0];
				if (stateId == null) {
					// the whole-object state is only offered once initialization completed
					synchronized (this) {
						if (initialized) {
							return Boolean.TRUE;
						}
						return null;
					}
				}
				if (ro.hasState(stateId)) {
					return Boolean.TRUE;
				}
				return null;
			} else if (id == sendStateId) {
				//sendState
				ReplicatedObject ro = (ReplicatedObject)object;
				// state ids are Serializable (see Streaming.sendState); casting to String
				// would fail for any non-String id
				Serializable stateId = (Serializable)method_call.getArgs()[0];
				Address dest = (Address)method_call.getArgs()[1];
				JGroupsOutputStream oStream = new JGroupsOutputStream(this, Arrays.asList(dest), stateId, (short)(methodMap.size() - 3), false);
				try {
					if (stateId == null) {
						ro.getState(oStream);
					} else {
						ro.getState(stateId, oStream);
					}
				} finally {
					oStream.close();
				}
				LogManager.logTrace(LogConstants.CTX_RUNTIME, object, "sent state", stateId); //$NON-NLS-1$
				return null;
			}
			Method m=method_lookup.findMethod(id);
			if(m == null)
				throw new Exception("no method found for " + id); //$NON-NLS-1$
			method_call.setMethod(m);
			return method_call.invoke(server_obj);
		}
		catch(Throwable x) {
			return x;
		}
	}
}
private static final long serialVersionUID = -6851804958313095166L;
// Names of the internal state-transfer methods that replicate(...) looks up
// reflectively and appends, in this order, after the @Replicated methods.
// HAS_STATE is resolved on ReplicatedObject; the others on the Streaming interface.
private static final String HAS_STATE = "hasState"; //$NON-NLS-1$
private static final String SEND_STATE = "sendState"; //$NON-NLS-1$
private static final String CREATE_STATE = "createState"; //$NON-NLS-1$
private static final String BUILD_STATE = "buildState"; //$NON-NLS-1$
private static final String FINISH_STATE = "finishState"; //$NON-NLS-1$
/**
 * Background task that feeds an incoming state stream into the replicated
 * object. When a promise is supplied it is completed with TRUE on success and
 * FALSE on failure; the stream is always closed at the end.
 */
private static final class StreamingRunner implements Runnable {
	private final Object object;
	private final Serializable stateId;
	private final JGroupsInputStream is;
	private final Promise<Boolean> promise;

	private StreamingRunner(Object object, Serializable stateId, JGroupsInputStream is, Promise<Boolean> promise) {
		this.promise = promise;
		this.is = is;
		this.stateId = stateId;
		this.object = object;
	}

	@Override
	public void run() {
		try {
			applyState();
			complete(Boolean.TRUE);
			LogManager.logDetail(LogConstants.CTX_RUNTIME, "state set", stateId); //$NON-NLS-1$
		} catch (Exception e) {
			complete(Boolean.FALSE);
			LogManager.logError(LogConstants.CTX_RUNTIME, e, RuntimePlugin.Util.gs(RuntimePlugin.Event.TEIID40101, stateId));
		} finally {
			is.close();
		}
	}

	/** Reads either the whole-object state (null id) or the keyed state from the stream. */
	private void applyState() throws Exception {
		ReplicatedObject target = (ReplicatedObject)object;
		if (stateId != null) {
			target.setState(stateId, is);
		} else {
			target.setState(is);
		}
	}

	/** Publishes the outcome to a waiting puller, if there is one. */
	private void complete(Boolean outcome) {
		if (promise == null) {
			return;
		}
		promise.setResult(outcome);
	}
}
/**
 * Proxy invocation handler and group membership listener for one replicated
 * object. Calls to @Replicated methods are sent to the cluster through the
 * dispatcher, or executed locally when there are no remote members. State
 * replication (push or pull) is handled according to the annotation.
 */
private final class ReplicatedInvocationHandler<S> extends ReceiverAdapter implements
		InvocationHandler, Serializable {

	private static final int PULL_RETRIES = 3;
	private static final long serialVersionUID = -2943462899945966103L;
	private final S object;
	private transient ReplicatorRpcDispatcher<S> disp;
	private final HashMap<Method, Short> methodMap;
	// current view members other than this node; maintained by viewAccepted
	protected List<Address> remoteMembers = Collections.synchronizedList(new ArrayList<Address>());
	// state ids currently being pulled, so concurrent pullers wait instead of duplicating work
	private Map<Serializable, Promise<Boolean>> loadingStates = new HashMap<Serializable, Promise<Boolean>>();

	private ReplicatedInvocationHandler(S object,HashMap<Method, Short> methodMap) {
		this.object = object;
		this.methodMap = methodMap;
	}

	/** Returns a snapshot of the remote members taken under the list's lock. */
	List<Address> getRemoteMembersCopy() {
		synchronized (remoteMembers) {
			return new ArrayList<Address>(remoteMembers);
		}
	}

	public void setDisp(ReplicatorRpcDispatcher<S> disp) {
		this.disp = disp;
	}

	/**
	 * Routes a proxy call. Methods not in the method table, methods without a
	 * {@link Replicated} annotation, and calls made while there are no remote
	 * members run locally (except remote-only methods, which then return null).
	 * Other calls are sent to the cluster. For boolean methods the result is true
	 * only if every response was TRUE. For Collection methods the per-member
	 * collections are concatenated.
	 */
	@Override
	public Object invoke(Object proxy, Method method, Object[] args)
			throws Throwable {
		Short methodNum = methodMap.get(method);
		Replicated annotation = null;
		if (methodNum != null) {
			annotation = method.getAnnotation(Replicated.class);
		}
		// internal table entries (e.g. hasState) carry no annotation; running them
		// remotely would dereference a null annotation, so treat them as local
		if (methodNum == null || annotation == null || remoteMembers.isEmpty()) {
			if (annotation != null && annotation.remoteOnly()) {
				return null;
			}
			return invokeLocally(method, args);
		}
		try {
			if (annotation.replicateState() != ReplicationMode.NONE) {
				return handleReplicateState(method, args, annotation);
			}
			MethodCall call=new MethodCall(methodNum, args);
			List<Address> dests = null;
			if (annotation.remoteOnly()) {
				dests = getRemoteMembersCopy();
				if (dests.isEmpty()) {
					return null;
				}
			}
			RspList<Object> responses = disp.callRemoteMethods(dests, call, new RequestOptions().setMode(annotation.asynch()?ResponseMode.GET_NONE:ResponseMode.GET_ALL).setTimeout(annotation.timeout()).setAnycasting(dests != null));
			if (annotation.asynch()) {
				return null;
			}
			List<Object> results = responses.getResults();
			if (method.getReturnType() == boolean.class) {
				for (Object o : results) {
					if (!Boolean.TRUE.equals(o)) {
						return false;
					}
				}
				return true;
			} else if (method.getReturnType() == Collection.class) {
				ArrayList<Object> result = new ArrayList<Object>();
				for (Object o : results) {
					result.addAll((Collection)o);
				}
				// return the merged collection, not the list of per-member collections
				return result;
			}
			return null;
		} catch(Exception e) {
			throw new RuntimeException(method + " " + Arrays.toString(args) + " failed", e); //$NON-NLS-1$ //$NON-NLS-2$
		}
	}

	/**
	 * Invokes {@code method} on the local target, rethrowing the underlying cause
	 * of an {@link InvocationTargetException} instead of the reflective wrapper.
	 */
	private Object invokeLocally(Method method, Object[] args) throws Throwable {
		try {
			return method.invoke(object, args);
		} catch (InvocationTargetException e) {
			throw e.getCause();
		}
	}

	/**
	 * Asks all remote members whether they hold {@code stateId} (null meaning
	 * the whole-object state) and returns the first member that answered TRUE.
	 * Returns null if there are no remote members or none has the state.
	 */
	protected Address whereIsState(Serializable stateId, long timeout) throws Exception {
		if (remoteMembers.isEmpty()) {
			return null;
		}
		RspList<Boolean> resp = this.disp.callRemoteMethods(getRemoteMembersCopy(), new MethodCall((short)(methodMap.size() - 5), new Object[]{stateId}), new RequestOptions(ResponseMode.GET_ALL, timeout));
		Collection<Rsp<Boolean>> values = resp.values();
		Rsp<Boolean> rsp = null;
		for (Rsp<Boolean> response : values) {
			if (Boolean.TRUE.equals(response.getValue())) {
				rsp = response;
				break;
			}
		}
		if (rsp == null) {
			return null;
		}
		return rsp.getSender();
	}

	/**
	 * Runs a state-replicating method locally first. In PUSH mode the resulting
	 * state for {@code args[0]} is then streamed to all members. Otherwise, when
	 * the local call returned null, the state is pulled from a remote member.
	 */
	private Object handleReplicateState(Method method, Object[] args,
			Replicated annotation) throws IllegalAccessException,
			Throwable, IOException, IllegalStateException, Exception {
		Object result = invokeLocally(method, args);
		ReplicatedObject ro = (ReplicatedObject)object;
		Serializable stateId = (Serializable)args[0];
		if (annotation.replicateState() == ReplicationMode.PUSH) {
			if (!remoteMembers.isEmpty()) {
				LogManager.logDetail(LogConstants.CTX_RUNTIME, object, "replicating state", stateId); //$NON-NLS-1$
				JGroupsOutputStream oStream = new JGroupsOutputStream(disp, null, stateId, (short)(methodMap.size() - 3), true);
				try {
					ro.getState(stateId, oStream);
				} finally {
					oStream.close();
				}
				LogManager.logTrace(LogConstants.CTX_RUNTIME, object, "sent state", stateId); //$NON-NLS-1$
			}
			return result;
		}
		if (result != null) {
			return result;
		}
		long timeout = annotation.timeout();
		return pullState(method, args, stateId, timeout, timeout);
	}

	/**
	 * Pull the remote state. The method and args are optional; when given they
	 * are invoked locally before and after the transfer to determine whether the
	 * state has become available, and a non-null result is returned.
	 * Only one thread pulls a given state id at a time. Other threads wait on its
	 * promise and retry, up to {@code PULL_RETRIES} attempts in total.
	 *
	 * @return the non-null result of {@code method}, or null if no result could be obtained
	 */
	Object pullState(Method method, Object[] args, Serializable stateId,
			long timeout, long stateDetectTimeout) throws Throwable {
		Object result = null;
		for (int i = 0; i < PULL_RETRIES; i++) {
			Promise<Boolean> p = null;
			boolean wait = true;
			synchronized (loadingStates) {
				p = loadingStates.get(stateId);
				if (p == null) {
					wait = false;
					if (method != null) {
						result = invokeLocally(method, args);
						if (result != null) {
							return result;
						}
					}
					p = new Promise<Boolean>();
					loadingStates.put(stateId, p);
				}
			}
			if (wait) {
				// another thread is already pulling this state; wait, then re-check
				p.getResult(timeout);
				continue;
			}
			try {
				LogManager.logDetail(LogConstants.CTX_RUNTIME, object, "pulling state", stateId); //$NON-NLS-1$
				Address addr = whereIsState(stateId, stateDetectTimeout);
				if (addr == null) {
					LogManager.logDetail(LogConstants.CTX_RUNTIME, object, "timeout exceeded or first member"); //$NON-NLS-1$
					break;
				}
				// register the receiving stream before asking the holder to send
				JGroupsInputStream is = new JGroupsInputStream(IO_TIMEOUT);
				StreamingRunner runner = new StreamingRunner(object, stateId, is, p);
				List<?> key = Arrays.asList(stateId, addr);
				disp.inputStreams.put(key, is);
				executor.execute(runner);
				this.disp.callRemoteMethod(addr, new MethodCall((short)(methodMap.size() - 4), stateId, this.disp.getChannel().getAddress()), new RequestOptions(ResponseMode.GET_NONE, 0).setAnycasting(true));
				Boolean fetched = p.getResult(timeout);
				if (fetched != null) {
					if (fetched) {
						LogManager.logDetail(LogConstants.CTX_RUNTIME, object, "pulled state", stateId); //$NON-NLS-1$
						if (method !=null) {
							result = invokeLocally(method, args);
							if (result != null) {
								return result;
							}
						}
						break;
					}
					LogManager.logWarning(LogConstants.CTX_RUNTIME, RuntimePlugin.Util.gs(RuntimePlugin.Event.TEIID40102, object, stateId));
				} else {
					LogManager.logWarning(LogConstants.CTX_RUNTIME, RuntimePlugin.Util.gs(RuntimePlugin.Event.TEIID40103, object, stateId));
				}
			} finally {
				synchronized (loadingStates) {
					loadingStates.remove(stateId);
				}
			}
		}
		return null; //could not fetch the remote state
	}

	/**
	 * Updates the remote member list for a new view. Members that are no longer
	 * present are reported to a ReplicatedObject target through droppedMembers.
	 */
	@Override
	public void viewAccepted(View newView) {
		if (newView.getMembers() != null) {
			synchronized (remoteMembers) {
				// whatever remains after removing the new view's members has left
				remoteMembers.removeAll(newView.getMembers());
				if (object instanceof ReplicatedObject<?> && !remoteMembers.isEmpty()) {
					HashSet<Serializable> dropped = new HashSet<Serializable>(remoteMembers);
					((ReplicatedObject<?>)object).droppedMembers(dropped);
				}
				remoteMembers.clear();
				remoteMembers.addAll(newView.getMembers());
				remoteMembers.remove(this.disp.getChannel().getAddress());
			}
		}
	}
}
/**
 * Signatures of the internal state-transfer calls. No class in this file
 * implements it. replicate(...) only looks up its Method objects by name and
 * parameter types and uses them as method-table entries, so names and
 * parameter types must not change. The dispatcher's handle method services
 * these calls itself.
 */
private interface Streaming {
	/** Requests that the receiver stream the given state (null = whole object) to {@code dest}. */
	void sendState(Serializable stateId, Address dest);

	/** Opens a new incoming stream for the given state id. */
	void createState(Serializable stateId);

	/** Delivers the next chunk of bytes for an open incoming stream. */
	void buildState(Serializable stateId, byte[] chunk);

	/** Marks the end of an incoming stream. */
	void finishState(Serializable stateId);
}
//TODO: this should be configurable, or use a common executor
// Runs the StreamingRunner tasks that apply incoming state streams.
// NOTE(review): transient, so it is null on an instance restored through Java
// deserialization - confirm such instances are never used to replicate.
private transient Executor executor;
// Creates the JGroups channel for each replicate(...) call; also transient.
private transient ChannelFactory channelFactory;
/**
 * Creates a replicator.
 *
 * @param channelFactory creates the JGroups channel used for each replicated object
 * @param executor runs the background tasks that apply incoming state streams
 */
public JGroupsObjectReplicator(ChannelFactory channelFactory, Executor executor) {
	this.executor = executor;
	this.channelFactory = channelFactory;
}
/**
 * Stops replication for a proxy returned by
 * {@link #replicate(String, Class, Object, long)}. It stops the dispatcher,
 * then disconnects and closes its channel. The channel is released even if
 * stopping the dispatcher fails. Null values, non-proxies and proxies backed by
 * a different InvocationHandler are ignored.
 */
public void stop(Object object) {
	if (object == null || !Proxy.isProxyClass(object.getClass())) {
		return;
	}
	InvocationHandler invocationHandler = Proxy.getInvocationHandler(object);
	if (!(invocationHandler instanceof ReplicatedInvocationHandler<?>)) {
		// a proxy not created by this replicator - nothing of ours to stop
		return;
	}
	ReplicatedInvocationHandler<?> handler = (ReplicatedInvocationHandler<?>) invocationHandler;
	Channel c = handler.disp.getChannel();
	try {
		handler.disp.stop();
	} finally {
		try {
			c.disconnect();
		} finally {
			c.close();
		}
	}
}
/**
 * Creates a proxy for {@code object} whose @Replicated methods of {@code iface}
 * are distributed over a JGroups channel named {@code mux_id}.
 * If the object is a ReplicatedObject, its address is set and an initial
 * whole-object state pull is attempted, bounded by {@code startTimeout}.
 * The channel is closed if connecting or the initial pull fails.
 */
@Override
public <T, S> T replicate(String mux_id,
		Class<T> iface, final S object, long startTimeout) throws Exception {
	Channel channel = channelFactory.createChannel(mux_id);

	// Sort by generic signature so every node assigns the same id to the same method.
	TreeMap<String, Method> sortedMethods = new TreeMap<String, Method>();
	for (Method candidate : iface.getMethods()) {
		if (candidate.getAnnotation(Replicated.class) != null) {
			sortedMethods.put(candidate.toGenericString(), candidate);
		}
	}

	// The internal state-transfer methods must stay last and in exactly this
	// order: the dispatcher and invocation handler address them by offset from the end.
	Method[] internalMethods = new Method[] {
		ReplicatedObject.class.getMethod(HAS_STATE, Serializable.class),
		JGroupsObjectReplicator.Streaming.class.getMethod(SEND_STATE, Serializable.class, Address.class),
		//streaming methods
		JGroupsObjectReplicator.Streaming.class.getMethod(CREATE_STATE, Serializable.class),
		JGroupsObjectReplicator.Streaming.class.getMethod(BUILD_STATE, Serializable.class, byte[].class),
		JGroupsObjectReplicator.Streaming.class.getMethod(FINISH_STATE, Serializable.class),
	};

	final ArrayList<Method> methodList = new ArrayList<Method>(sortedMethods.values());
	methodList.addAll(Arrays.asList(internalMethods));
	final HashMap<Method, Short> methodMap = new HashMap<Method, Short>();
	for (int index = 0; index < methodList.size(); index++) {
		methodMap.put(methodList.get(index), (short) index);
	}

	ReplicatedInvocationHandler<S> handler = new ReplicatedInvocationHandler<S>(object, methodMap);
	/*
	 * TODO: could have an object implement streaming
	 * Override the normal handle method to support streaming
	 */
	ReplicatorRpcDispatcher<S> disp = new ReplicatorRpcDispatcher<S>(channel, handler, handler, object, object, methodMap, methodList);
	handler.setDisp(disp);
	disp.setMethodLookup(new MethodLookup() {
		public Method findMethod(short id) {
			return methodList.get(id);
		}
	});

	T replicatedProxy = (T) Proxy.newProxyInstance(Thread.currentThread().getContextClassLoader(), new Class[] {iface}, handler);
	boolean started = false;
	try {
		channel.connect(mux_id);
		if (object instanceof ReplicatedObject) {
			((ReplicatedObject)object).setAddress(channel.getAddress());
			long detectTimeout = (startTimeout == 0) ? 0 : STATE_TIMEOUT;
			handler.pullState(null, null, null, startTimeout, detectTimeout);
		}
		started = true;
		return replicatedProxy;
	} catch (Throwable e) {
		if (!(e instanceof Exception)) {
			throw new TeiidRuntimeException(RuntimePlugin.Event.TEIID40104, e);
		}
		throw (Exception)e;
	} finally {
		if (started) {
			synchronized (disp) {
				//mark as initialized so that state can be pulled if needed
				disp.initialized = true;
			}
		} else {
			channel.close();
		}
	}
}
}