/*
* RHQ Management Platform
* Copyright (C) 2005-2008 Red Hat, Inc.
* All rights reserved.
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation version 2 of the License.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
package org.rhq.enterprise.server.remote;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import javax.management.MBeanServer;
import javax.naming.InitialContext;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.jboss.remoting.InvocationRequest;
import org.jboss.remoting.ServerInvocationHandler;
import org.jboss.remoting.ServerInvoker;
import org.jboss.remoting.callback.InvokerCallbackHandler;
import org.jboss.remoting.invocation.NameBasedInvocation;
import org.rhq.core.domain.server.ExternalizableStrategy;
import org.rhq.core.util.exception.WrappedRemotingException;
import org.rhq.enterprise.server.safeinvoker.HibernateDetachUtility;
/**
* Handle remote invocations. Note that we perform only invocations defined in the remote interfaces.
* However, we execute them on the locals to bypass the serialization performed by a remote invocation. Even
* though this handler is co-located, for remotes, remoting will serialize the return data immediately.
* This is bad for us because, since we return domain objects, we need to scrub the data, removing
* hibernate proxies (see {@link HibernateDetachUtility}).
*
* @author Greg Hinkle
* @author Jay Shaughnessy
*/
public class RemoteSafeInvocationHandler implements ServerInvocationHandler {

    private static final Log log = LogFactory.getLog(RemoteSafeInvocationHandler.class);

    /** Common prefix of the java:global JNDI names under which the server beans are looked up. */
    private static final String JNDI_PREFIX = "java:global/rhq/rhq-server/";

    /**
     * Maps primitive type names (e.g. "int") to their classes; {@link Class#forName(String)} cannot resolve
     * primitive names, so signatures containing them are resolved through this map instead.
     */
    private static final Map<String, Class<?>> PRIMITIVE_CLASSES;

    /** Cache: remote interface class -> JNDI name of the corresponding local view. */
    private static final ConcurrentHashMap<Class<?>, String> LOCAL_JNDI_NAMES;

    /** Cache: remote interface class -> JNDI name of the remote view. */
    private static final ConcurrentHashMap<Class<?>, String> REMOTE_JNDI_NAMES;

    static {
        PRIMITIVE_CLASSES = new HashMap<String, Class<?>>();
        PRIMITIVE_CLASSES.put(Short.TYPE.getName(), Short.TYPE);
        PRIMITIVE_CLASSES.put(Integer.TYPE.getName(), Integer.TYPE);
        PRIMITIVE_CLASSES.put(Long.TYPE.getName(), Long.TYPE);
        PRIMITIVE_CLASSES.put(Float.TYPE.getName(), Float.TYPE);
        PRIMITIVE_CLASSES.put(Double.TYPE.getName(), Double.TYPE);
        PRIMITIVE_CLASSES.put(Boolean.TYPE.getName(), Boolean.TYPE);
        PRIMITIVE_CLASSES.put(Character.TYPE.getName(), Character.TYPE);
        PRIMITIVE_CLASSES.put(Byte.TYPE.getName(), Byte.TYPE);

        LOCAL_JNDI_NAMES = new ConcurrentHashMap<Class<?>, String>();
        REMOTE_JNDI_NAMES = new ConcurrentHashMap<Class<?>, String>();
    }

    /** Collects per-method timing/success data; exposed via {@link #registerMetricsMBean(MBeanServer)}. */
    private final RemoteSafeInvocationHandlerMetrics metrics = new RemoteSafeInvocationHandlerMetrics();

    /**
     * Performs a remote API invocation. The request parameter must be a {@link NameBasedInvocation} whose
     * method name has the form {@code <remote interface class name>:<method name>}. The method is first
     * looked up on the bean's remote view (so remote clients cannot reach methods that exist only on the
     * local view) and is then invoked on the local view, after which any non-null result is scrubbed with
     * {@link HibernateDetachUtility}.
     *
     * @param invocationRequest the incoming request; must not be null
     * @return the invocation result (possibly null), or a {@link WrappedRemotingException} wrapping any
     *         {@link Exception} raised while resolving, invoking or scrubbing; such failures are returned,
     *         not thrown
     * @throws IllegalArgumentException if {@code invocationRequest} is null
     * @throws Throwable only non-{@link Exception} throwables (e.g. {@link Error}s) raised outside the
     *         target method itself; these propagate after metrics have been recorded
     */
    public Object invoke(InvocationRequest invocationRequest) throws Throwable {
        if (invocationRequest == null) {
            throw new IllegalArgumentException("InvocationRequest was null.");
        }

        String methodName = null;
        boolean successful = false; // flipped to true only once the result has been produced and scrubbed
        long time = System.currentTimeMillis();
        InitialContext ic = null;

        try {
            ic = new InitialContext();

            NameBasedInvocation nbi = ((NameBasedInvocation) invocationRequest.getParameter());
            if (null == nbi) {
                throw new IllegalArgumentException("InvocationRequest did not supply method.");
            }

            methodName = nbi.getMethodName();
            String[] methodInfo = parseMethodName(methodName);
            Class<?> remoteClass = getClass(methodInfo[0]);
            Class<?>[] sig = getSignatureClasses(nbi.getSignature());

            // make sure the remote method is defined to ensure remote clients don't access locals;
            // getMethod throws NoSuchMethodException when the remote view does not expose the method
            Object target = ic.lookup(getRemoteJNDIName(remoteClass));
            target.getClass().getMethod(methodInfo[1], sig);

            // switch to the local, so the result is not serialized before it has been scrubbed below
            target = ic.lookup(getLocalJNDIName(remoteClass));
            Method m = target.getClass().getMethod(methodInfo[1], sig);
            Object result = m.invoke(target, nbi.getParameters());

            if (result != null) {
                // set the strategy guiding how the return information is serialized
                ExternalizableStrategy.setStrategy(ExternalizableStrategy.Subsystem.REFLECTIVE_SERIALIZATION);

                // scrub the return data if Hibernate proxies
                try {
                    HibernateDetachUtility.nullOutUninitializedFields(result,
                        HibernateDetachUtility.SerializationType.SERIALIZATION);
                } catch (Exception e) {
                    log.error("Failed to null out uninitialized fields", e);
                    return new WrappedRemotingException(e);
                }
            }

            successful = true;
            return result;
        } catch (InvocationTargetException e) {
            log.error("Failed to invoke remote request", e);
            return new WrappedRemotingException(e.getTargetException());
        } catch (Exception e) {
            log.error("Failed to invoke remote request", e);
            return new WrappedRemotingException(e);
        } finally {
            closeContext(ic);

            // calculated after the hibernate util so that its cost is taken into account too
            long executionTime = System.currentTimeMillis() - time;
            this.metrics.addData(methodName, executionTime, successful);
            if (log.isDebugEnabled()) {
                log.debug("Remote request [" + methodName + "] execution time (ms): " + executionTime);
            }
        }
    }

    /**
     * Splits an invocation method name of the form {@code <class name>:<method name>}.
     *
     * @param methodName the name supplied by the client; may be null
     * @return the split parts; index 0 is the class name and index 1 the method name
     * @throws IllegalArgumentException if the name is null or does not contain both parts
     */
    private static String[] parseMethodName(String methodName) {
        String[] methodInfo = (methodName == null) ? new String[0] : methodName.split(":");
        if (methodInfo.length < 2) {
            throw new IllegalArgumentException("Invalid method name [" + methodName
                + "], expected <remote interface class name>:<method name>");
        }
        return methodInfo;
    }

    /**
     * Resolves the parameter type names of an invocation into classes.
     *
     * @param signature the parameter type names; null is treated as an empty signature
     * @return the resolved parameter types, in order
     * @throws ClassNotFoundException if a type name cannot be resolved
     */
    private static Class<?>[] getSignatureClasses(String[] signature) throws ClassNotFoundException {
        if (signature == null) {
            return new Class<?>[0];
        }
        Class<?>[] classes = new Class<?>[signature.length];
        for (int i = 0; i < signature.length; i++) {
            classes[i] = getClass(signature[i]);
        }
        return classes;
    }

    /**
     * Closes the naming context, if one was created. Failures are only logged: cleanup must never replace
     * the outcome of the invocation.
     */
    private static void closeContext(InitialContext ic) {
        if (ic != null) {
            try {
                ic.close();
            } catch (Exception e) {
                log.warn("Failed to close naming context: " + e);
            }
        }
    }

    /**
     * Returns the JNDI name of the local view of the bean implementing the given remote interface, e.g.
     * {@code com.acme.FooRemote} -> {@code java:global/rhq/rhq-server/FooBean!com.acme.FooLocal}.
     * The name is cached; a concurrent first lookup may compute it twice, which is harmless since both
     * threads compute the same value.
     */
    private static String getLocalJNDIName(Class<?> remoteClass) {
        String jndiName = LOCAL_JNDI_NAMES.get(remoteClass);
        if (jndiName == null) {
            jndiName = getBeanJNDIPrefix(remoteClass) + remoteClass.getName().replaceFirst("Remote$", "Local");
            LOCAL_JNDI_NAMES.put(remoteClass, jndiName);
        }
        return jndiName;
    }

    /**
     * Returns the JNDI name of the remote view of the bean implementing the given remote interface, e.g.
     * {@code com.acme.FooRemote} -> {@code java:global/rhq/rhq-server/FooBean!com.acme.FooRemote}.
     * Cached the same way as {@link #getLocalJNDIName(Class)}.
     */
    private static String getRemoteJNDIName(Class<?> remoteClass) {
        String jndiName = REMOTE_JNDI_NAMES.get(remoteClass);
        if (jndiName == null) {
            jndiName = getBeanJNDIPrefix(remoteClass) + remoteClass.getName();
            REMOTE_JNDI_NAMES.put(remoteClass, jndiName);
        }
        return jndiName;
    }

    /**
     * Builds the bean part of a JNDI name shared by the local and remote views:
     * {@code java:global/rhq/rhq-server/<simple name with trailing "Remote" replaced by "Bean">!}.
     */
    private static String getBeanJNDIPrefix(Class<?> remoteClass) {
        return JNDI_PREFIX + remoteClass.getSimpleName().replaceFirst("Remote$", "Bean") + "!";
    }

    /**
     * Resolves a type name from an invocation signature. Primitive names are looked up in
     * {@link #PRIMITIVE_CLASSES}; everything else goes through {@link Class#forName(String)}.
     *
     * TODO GH: Doesn't support arrays. NOTE(review): Class.forName does accept JVM binary array names such
     * as "[I" or "[Ljava.lang.String;", but source-style names like "int[]" are not handled - confirm which
     * form clients send.
     */
    private static Class<?> getClass(String name) throws ClassNotFoundException {
        Class<?> primitive = PRIMITIVE_CLASSES.get(name);
        return (primitive != null) ? primitive : Class.forName(name);
    }

    /**
     * Registers the MBean used to monitor the remote API processing.
     *
     * @param mbs the MBeanServer where the metrics MBean should be registered
     */
    public void registerMetricsMBean(MBeanServer mbs) {
        try {
            mbs.registerMBean(this.metrics, RemoteSafeInvocationHandlerMetricsMBean.OBJECTNAME_METRICS);
        } catch (Exception e) {
            log.warn("Failed to register the metrics object, will not be able to monitor remote API: " + e);
        }
    }

    /**
     * Unregisters the MBean that was used to monitor the remote API processing.
     *
     * @param mbs the MBeanServer where the metrics MBean is registered
     */
    public void unregisterMetricsMBean(MBeanServer mbs) {
        try {
            mbs.unregisterMBean(RemoteSafeInvocationHandlerMetricsMBean.OBJECTNAME_METRICS);
        } catch (Exception e) {
            log.warn("Failed to unregister the metrics object: " + e);
        }
    }

    /** No-op: this handler does not track callback listeners. */
    public void addListener(InvokerCallbackHandler handler) {
    }

    /** No-op: this handler does not track callback listeners. */
    public void removeListener(InvokerCallbackHandler handler) {
    }

    /** No-op: the invoker reference is not used by this handler. */
    public void setInvoker(ServerInvoker invoker) {
    }

    /** No-op: the MBeanServer is supplied explicitly to {@link #registerMetricsMBean(MBeanServer)} instead. */
    public void setMBeanServer(MBeanServer mbs) {
    }
}