/* * JBoss, Home of Professional Open Source. * Copyright 2008, Red Hat Middleware LLC, and individual contributors * as indicated by the @author tags. See the copyright.txt file in the * distribution for a full listing of individual contributors. * * This is free software; you can redistribute it and/or modify it * under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2.1 of * the License, or (at your option) any later version. * * This software is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU * Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with this software; if not, write to the Free * Software Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA * 02110-1301 USA, or see the FSF site: http://www.fsf.org. */ package org.jboss.ejb.plugins.local; import java.lang.reflect.InvocationHandler; import java.lang.reflect.Method; import java.lang.reflect.Proxy; import java.lang.reflect.Constructor; import java.rmi.AccessException; import java.rmi.NoSuchObjectException; import java.security.Principal; import java.security.PrivilegedAction; import java.security.AccessController; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.Iterator; import java.util.Map; import javax.ejb.AccessLocalException; import javax.ejb.EJBLocalHome; import javax.ejb.EJBLocalObject; import javax.ejb.NoSuchObjectLocalException; import javax.ejb.TransactionRequiredLocalException; import javax.ejb.TransactionRolledbackLocalException; import javax.naming.Context; import javax.naming.InitialContext; import javax.transaction.Transaction; import javax.transaction.TransactionManager; import javax.transaction.TransactionRequiredException; import javax.transaction.TransactionRolledbackException; import org.jboss.ejb.Container; import org.jboss.ejb.EJBProxyFactoryContainer; import org.jboss.ejb.LocalProxyFactory; import org.jboss.invocation.InvocationType; import org.jboss.invocation.MarshalledInvocation; import org.jboss.invocation.LocalEJBInvocation; import org.jboss.logging.Logger; import org.jboss.metadata.BeanMetaData; import org.jboss.naming.Util; import org.jboss.security.SecurityContext; import org.jboss.security.SecurityContextAssociation; import org.jboss.util.NestedRuntimeException; import org.jboss.tm.TransactionLocal; /** * The LocalProxyFactory implementation that handles local ejb interface * proxies. * * @author <a href="mailto:docodan@mvcsoft.com">Daniel OConnor</a> * @author <a href="mailto:scott.stark@jboss.org">Scott Stark</a> * @author <a href="mailto:dain@daingroup.com">Dain Sundstrom</a> * @author <a href="mailto:alex@jboss.org">Alexey Loubyansky</a> * @author Anil.Saldhana@redhat.com * $Revision: 81030 $ */ public class BaseLocalProxyFactory implements LocalProxyFactory { // Attributes ---------------------------------------------------- protected static Logger log = Logger.getLogger(BaseLocalProxyFactory.class); /** * A map of the BaseLocalProxyFactory instances keyed by localJndiName */ protected static Map invokerMap = Collections.synchronizedMap(new HashMap()); protected Container container; /** * The JNDI name of the local home interface binding */ protected String localJndiName; protected TransactionManager transactionManager; // The home can be one. protected EJBLocalHome home; // The Stateless Object can be one. protected EJBLocalObject statelessObject; protected Map beanMethodInvokerMap; protected Map homeMethodInvokerMap; protected Class localHomeClass; protected Class localClass; protected Constructor proxyClassConstructor; private final TransactionLocal cache = new TransactionLocal() { protected Object initialValue() { return new HashMap(); } }; // Static -------------------------------------------------------- // Constructors -------------------------------------------------- // Public -------------------------------------------------------- // ContainerService implementation ------------------------------- public void setContainer(Container con) { this.container = con; } public void create() throws Exception { BeanMetaData metaData = container.getBeanMetaData(); localJndiName = metaData.getLocalJndiName(); } public void start() throws Exception { BeanMetaData metaData = container.getBeanMetaData(); EJBProxyFactoryContainer invokerContainer = (EJBProxyFactoryContainer) container; localHomeClass = invokerContainer.getLocalHomeClass(); localClass = invokerContainer.getLocalClass(); if(localHomeClass == null || localClass == null) { log.debug(metaData.getEjbName() + " cannot be Bound, doesn't " + "have local and local home interfaces"); return; } // this is faster than newProxyInstance Class[] intfs = {localClass}; Class proxyClass = Proxy.getProxyClass(ClassLoaderAction.UTIL.get(localClass), intfs); final Class[] constructorParams = {InvocationHandler.class}; proxyClassConstructor = proxyClass.getConstructor(constructorParams); Context iniCtx = new InitialContext(); String beanName = metaData.getEjbName(); // Set the transaction manager and transaction propagation // context factory of the GenericProxy class transactionManager = (TransactionManager) iniCtx.lookup("java:/TransactionManager"); // Create method mappings for container invoker Method[] methods = localClass.getMethods(); beanMethodInvokerMap = new HashMap(); for(int i = 0; i < methods.length; i++) { long hash = MarshalledInvocation.calculateHash(methods[i]); beanMethodInvokerMap.put(new Long(hash), methods[i]); } methods = localHomeClass.getMethods(); homeMethodInvokerMap = new HashMap(); for(int i = 0; i < methods.length; i++) { long hash = MarshalledInvocation.calculateHash(methods[i]); homeMethodInvokerMap.put(new Long(hash), methods[i]); } // bind that referance to my name Util.rebind(iniCtx, localJndiName, getEJBLocalHome()); invokerMap.put(localJndiName, this); log.info("Bound EJB LocalHome '" + beanName + "' to jndi '" + localJndiName + "'"); } public void stop() { // Clean up the home proxy binding try { if(invokerMap.remove(localJndiName) == this) { log.info("Unbind EJB LocalHome '" + container.getBeanMetaData().getEjbName() + "' from jndi '" + localJndiName + "'"); InitialContext ctx = new InitialContext(); ctx.unbind(localJndiName); } } catch(Exception ignore) { } } public void destroy() { if(beanMethodInvokerMap != null) { beanMethodInvokerMap.clear(); } if(homeMethodInvokerMap != null) { homeMethodInvokerMap.clear(); } MarshalledInvocation.removeHashes(localHomeClass); MarshalledInvocation.removeHashes(localClass); container = null; } public Constructor getProxyClassConstructor() { if(proxyClassConstructor == null) { } return proxyClassConstructor; } // EJBProxyFactory implementation ------------------------------- public EJBLocalHome getEJBLocalHome() { if(home == null) { EJBProxyFactoryContainer cic = (EJBProxyFactoryContainer) container; InvocationHandler handler = new LocalHomeProxy(localJndiName, this); ClassLoader loader = ClassLoaderAction.UTIL.get(cic.getLocalHomeClass()); Class[] interfaces = {cic.getLocalHomeClass()}; home = (EJBLocalHome) Proxy.newProxyInstance(loader, interfaces, handler); } return home; } public EJBLocalObject getStatelessSessionEJBLocalObject() { if(statelessObject == null) { EJBProxyFactoryContainer cic = (EJBProxyFactoryContainer) container; InvocationHandler handler = new StatelessSessionProxy(localJndiName, this); ClassLoader loader = ClassLoaderAction.UTIL.get(cic.getLocalClass()); Class[] interfaces = {cic.getLocalClass()}; statelessObject = (EJBLocalObject) Proxy.newProxyInstance(loader, interfaces, handler); } return statelessObject; } public EJBLocalObject getStatefulSessionEJBLocalObject(Object id) { InvocationHandler handler = new StatefulSessionProxy(localJndiName, id, this); try { return (EJBLocalObject) proxyClassConstructor.newInstance(new Object[]{handler}); } catch(Exception ex) { throw new NestedRuntimeException(ex); } } public Object getEntityEJBObject(Object id) { return getEntityEJBLocalObject(id); } public EJBLocalObject getEntityEJBLocalObject(Object id, boolean create) { EJBLocalObject result = null; if(id != null) { final Transaction tx = cache.getTransaction(); if(tx == null) { result = createEJBLocalObject(id); } else { Map map = (Map) cache.get(tx); if(create) { result = createEJBLocalObject(id); map.put(id, result); } else { result = (EJBLocalObject) map.get(id); if(result == null) { result = createEJBLocalObject(id); map.put(id, result); } } } } return result; } public EJBLocalObject getEntityEJBLocalObject(Object id) { return getEntityEJBLocalObject(id, false); } public Collection getEntityLocalCollection(Collection ids) { ArrayList list = new ArrayList(ids.size()); Iterator iter = ids.iterator(); while(iter.hasNext()) { final Object nextId = iter.next(); list.add(getEntityEJBLocalObject(nextId)); } return list; } /** * Invoke a Home interface method. */ public Object invokeHome(Method m, Object[] args) throws Exception { // Set the right context classloader ClassLoader oldCl = TCLAction.UTIL.getContextClassLoader(); boolean setCl = !oldCl.equals(container.getClassLoader()); if(setCl) { TCLAction.UTIL.setContextClassLoader(container.getClassLoader()); } container.pushENC(); SecurityActions sa = SecurityActions.UTIL.getSecurityActions(); try { LocalEJBInvocation invocation = new LocalEJBInvocation(null, m, args, getTransaction(), sa.getPrincipal(), sa.getCredential()); invocation.setType(InvocationType.LOCALHOME); return container.invoke(invocation); } catch(AccessException ae) { log.trace(ae); throw new AccessLocalException(ae.getMessage(), ae); } catch(NoSuchObjectException nsoe) { throw new NoSuchObjectLocalException(nsoe.getMessage(), nsoe); } catch(TransactionRequiredException tre) { throw new TransactionRequiredLocalException(tre.getMessage()); } catch(TransactionRolledbackException trbe) { throw new TransactionRolledbackLocalException(trbe.getMessage(), trbe); } finally { container.popENC(); if(setCl) { TCLAction.UTIL.setContextClassLoader(oldCl); } } } public String getJndiName() { return localJndiName; } /** * Return the transaction associated with the current thread. * Returns <code>null</code> if the transaction manager was never * set, or if no transaction is associated with the current thread. */ Transaction getTransaction() throws javax.transaction.SystemException { if(transactionManager == null) { return null; } return transactionManager.getTransaction(); } /** * Invoke a local interface method. */ public Object invoke(Object id, Method m, Object[] args) throws Exception { // Set the right context classloader ClassLoader oldCl = TCLAction.UTIL.getContextClassLoader(); boolean setCl = !oldCl.equals(container.getClassLoader()); if(setCl) { TCLAction.UTIL.setContextClassLoader(container.getClassLoader()); } container.pushENC(); SecurityActions sa = SecurityActions.UTIL.getSecurityActions(); try { LocalEJBInvocation invocation = new LocalEJBInvocation(id, m, args, getTransaction(), sa.getPrincipal(), sa.getCredential()); invocation.setType(InvocationType.LOCAL); return container.invoke(invocation); } catch(AccessException ae) { log.trace(ae); throw new AccessLocalException(ae.getMessage(), ae); } catch(NoSuchObjectException nsoe) { throw new NoSuchObjectLocalException(nsoe.getMessage(), nsoe); } catch(TransactionRequiredException tre) { throw new TransactionRequiredLocalException(tre.getMessage()); } catch(TransactionRolledbackException trbe) { throw new TransactionRolledbackLocalException(trbe.getMessage(), trbe); } finally { container.popENC(); if(setCl) { TCLAction.UTIL.setContextClassLoader(oldCl); } } } private EJBLocalObject createEJBLocalObject(Object id) { InvocationHandler handler = new EntityProxy(localJndiName, id, this); try { return (EJBLocalObject) proxyClassConstructor.newInstance(new Object[]{handler}); } catch(Exception ex) { throw new NestedRuntimeException(ex); } } interface ClassLoaderAction { class UTIL { static ClassLoaderAction getClassLoaderAction() { return System.getSecurityManager() == null ? NON_PRIVILEGED : PRIVILEGED; } static ClassLoader get(Class clazz) { return getClassLoaderAction().get(clazz); } } ClassLoaderAction PRIVILEGED = new ClassLoaderAction() { public ClassLoader get(final Class clazz) { return (ClassLoader)AccessController.doPrivileged( new PrivilegedAction() { public Object run() { return clazz.getClassLoader(); } } ); } }; ClassLoaderAction NON_PRIVILEGED = new ClassLoaderAction() { public ClassLoader get(Class clazz) { return clazz.getClassLoader(); } }; ClassLoader get(Class clazz); } interface SecurityActions { class UTIL { static SecurityActions getSecurityActions() { return System.getSecurityManager() == null ? NON_PRIVILEGED : PRIVILEGED; } } SecurityActions NON_PRIVILEGED = new SecurityActions() { public Principal getPrincipal() { SecurityContext sc = getSecurityContext(); if(sc == null) return null; return sc.getUtil().getUserPrincipal(); } public Object getCredential() { SecurityContext sc = getSecurityContext(); if(sc == null) return null; return sc.getUtil().getCredential(); } public SecurityContext getSecurityContext() { return SecurityContextAssociation.getSecurityContext(); } }; SecurityActions PRIVILEGED = new SecurityActions() { private final PrivilegedAction getPrincipalAction = new PrivilegedAction() { public Object run() { SecurityContext sc = getSecurityContext(); if(sc == null) return null; return sc.getUtil().getUserPrincipal(); } }; private final PrivilegedAction getCredentialAction = new PrivilegedAction() { public Object run() { SecurityContext sc = getSecurityContext(); if(sc == null) return null; return sc.getUtil().getCredential(); } }; public Principal getPrincipal() { return (Principal)AccessController.doPrivileged(getPrincipalAction); } public Object getCredential() { return AccessController.doPrivileged(getCredentialAction); } public SecurityContext getSecurityContext() { return (SecurityContext)AccessController.doPrivileged( new PrivilegedAction(){ public Object run() { return SecurityContextAssociation.getSecurityContext(); }}); } }; Principal getPrincipal(); Object getCredential(); SecurityContext getSecurityContext(); } interface TCLAction { class UTIL { static TCLAction getTCLAction() { return System.getSecurityManager() == null ? NON_PRIVILEGED : PRIVILEGED; } static ClassLoader getContextClassLoader() { return getTCLAction().getContextClassLoader(); } static ClassLoader getContextClassLoader(Thread thread) { return getTCLAction().getContextClassLoader(thread); } static void setContextClassLoader(ClassLoader cl) { getTCLAction().setContextClassLoader(cl); } static void setContextClassLoader(Thread thread, ClassLoader cl) { getTCLAction().setContextClassLoader(thread, cl); } } TCLAction NON_PRIVILEGED = new TCLAction() { public ClassLoader getContextClassLoader() { return Thread.currentThread().getContextClassLoader(); } public ClassLoader getContextClassLoader(Thread thread) { return thread.getContextClassLoader(); } public void setContextClassLoader(ClassLoader cl) { Thread.currentThread().setContextClassLoader(cl); } public void setContextClassLoader(Thread thread, ClassLoader cl) { thread.setContextClassLoader(cl); } }; TCLAction PRIVILEGED = new TCLAction() { private final PrivilegedAction getTCLPrivilegedAction = new PrivilegedAction() { public Object run() { return Thread.currentThread().getContextClassLoader(); } }; public ClassLoader getContextClassLoader() { return (ClassLoader)AccessController.doPrivileged(getTCLPrivilegedAction); } public ClassLoader getContextClassLoader(final Thread thread) { return (ClassLoader)AccessController.doPrivileged(new PrivilegedAction() { public Object run() { return thread.getContextClassLoader(); } }); } public void setContextClassLoader(final ClassLoader cl) { AccessController.doPrivileged( new PrivilegedAction() { public Object run() { Thread.currentThread().setContextClassLoader(cl); return null; } } ); } public void setContextClassLoader(final Thread thread, final ClassLoader cl) { AccessController.doPrivileged( new PrivilegedAction() { public Object run() { thread.setContextClassLoader(cl); return null; } } ); } }; ClassLoader getContextClassLoader(); ClassLoader getContextClassLoader(Thread thread); void setContextClassLoader(ClassLoader cl); void setContextClassLoader(Thread thread, ClassLoader cl); } }