/** * Copyright 2005 JBoss Inc * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.drools.rule; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.Externalizable; import java.io.IOException; import java.io.InputStream; import java.io.ObjectInput; import java.io.ObjectInputStream; import java.io.ObjectOutput; import java.io.ObjectOutputStream; import java.net.URL; import java.security.AccessController; import java.security.InvalidKeyException; import java.security.KeyStoreException; import java.security.NoSuchAlgorithmException; import java.security.PrivilegedAction; import java.security.ProtectionDomain; import java.security.SignatureException; import java.util.ArrayList; import java.util.Collections; import java.util.Enumeration; import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Map.Entry; import org.drools.RuntimeDroolsException; import org.drools.common.DroolsObjectInputStream; import org.drools.core.util.DroolsClassLoader; import org.drools.core.util.KeyStoreHelper; import org.drools.core.util.StringUtils; import org.drools.spi.Wireable; import org.drools.util.CompositeClassLoader; import org.drools.util.FastClassLoader; public class JavaDialectRuntimeData implements DialectRuntimeData, Externalizable { /** * */ private static final long serialVersionUID = 510l; private static final ProtectionDomain PROTECTION_DOMAIN; private Map invokerLookups; private Map store; private DialectRuntimeRegistry registry; private transient PackageClassLoader classLoader; private transient CompositeClassLoader rootClassLoader; private boolean dirty; private List<String> wireList = Collections.<String> emptyList(); static { PROTECTION_DOMAIN = (ProtectionDomain) AccessController.doPrivileged( new PrivilegedAction() { public Object run() { return JavaDialectRuntimeData.class.getProtectionDomain(); } } ); } public JavaDialectRuntimeData() { this.invokerLookups = new HashMap(); this.store = new HashMap(); this.dirty = false; } /** * Handles the write serialization of the PackageCompilationData. Patterns in Rules may reference generated data which cannot be serialized by * default methods. The PackageCompilationData holds a reference to the generated bytecode. The generated bytecode must be restored before any Rules. * */ public void writeExternal(ObjectOutput stream) throws IOException { KeyStoreHelper helper = new KeyStoreHelper(); stream.writeBoolean( helper.isSigned() ); if ( helper.isSigned() ) { stream.writeObject( helper.getPvtKeyAlias() ); } ByteArrayOutputStream bos = new ByteArrayOutputStream(); ObjectOutput out = new ObjectOutputStream( bos ); out.writeInt( this.store.size() ); for ( Iterator it = this.store.entrySet().iterator(); it.hasNext(); ) { Entry entry = (Entry) it.next(); out.writeObject( entry.getKey() ); out.writeObject( entry.getValue() ); } out.flush(); out.close(); byte[] buff = bos.toByteArray(); stream.writeObject( buff ); if ( helper.isSigned() ) { sign( stream, helper, buff ); } stream.writeInt( this.invokerLookups.size() ); for ( Iterator it = this.invokerLookups.entrySet().iterator(); it.hasNext(); ) { Entry entry = (Entry) it.next(); stream.writeObject( entry.getKey() ); stream.writeObject( entry.getValue() ); } } private void sign(final ObjectOutput stream, KeyStoreHelper helper, byte[] buff) { try { stream.writeObject( helper.signDataWithPrivateKey( buff ) ); } catch ( Exception e ) { throw new RuntimeDroolsException( "Error signing object store: " + e.getMessage(), e ); } } /** * Handles the read serialization of the PackageCompilationData. Patterns in Rules may reference generated data which cannot be serialized by * default methods. The PackageCompilationData holds a reference to the generated bytecode; which must be restored before any Rules. * A custom ObjectInputStream, able to resolve classes against the bytecode, is used to restore the Rules. * */ public void readExternal(ObjectInput stream) throws IOException, ClassNotFoundException { KeyStoreHelper helper = new KeyStoreHelper(); boolean signed = stream.readBoolean(); if ( helper.isSigned() != signed ) { throw new RuntimeDroolsException( "This environment is configured to work with " + (helper.isSigned() ? "signed" : "unsigned") + " serialized objects, but the given object is " + (signed ? "signed" : "unsigned") + ". Deserialization aborted." ); } String pubKeyAlias = null; if ( signed ) { pubKeyAlias = (String) stream.readObject(); if ( helper.getPubKeyStore() == null ) { throw new RuntimeDroolsException( "The package was serialized with a signature. Please configure a public keystore with the public key to check the signature. Deserialization aborted." ); } } // Return the object stored as a byte[] byte[] bytes = (byte[]) stream.readObject(); if ( signed ) { checkSignature( stream, helper, bytes, pubKeyAlias ); } ObjectInputStream in = new ObjectInputStream( new ByteArrayInputStream( bytes ) ); for ( int i = 0, length = in.readInt(); i < length; i++ ) { this.store.put( in.readObject(), in.readObject() ); } in.close(); for ( int i = 0, length = stream.readInt(); i < length; i++ ) { this.invokerLookups.put( stream.readObject(), stream.readObject() ); } // mark it as dirty, so that it reloads everything. this.dirty = true; } private void checkSignature(final ObjectInput stream, final KeyStoreHelper helper, final byte[] bytes, final String pubKeyAlias) throws ClassNotFoundException, IOException { byte[] signature = (byte[]) stream.readObject(); try { if ( !helper.checkDataWithPublicKey( pubKeyAlias, bytes, signature ) ) { throw new RuntimeDroolsException( "Signature does not match serialized package. This is a security violation. Deserialisation aborted." ); } } catch ( InvalidKeyException e ) { throw new RuntimeDroolsException( "Invalid key checking signature: " + e.getMessage(), e ); } catch ( KeyStoreException e ) { throw new RuntimeDroolsException( "Error accessing Key Store: " + e.getMessage(), e ); } catch ( NoSuchAlgorithmException e ) { throw new RuntimeDroolsException( "No algorithm available: " + e.getMessage(), e ); } catch ( SignatureException e ) { throw new RuntimeDroolsException( "Signature Exception: " + e.getMessage(), e ); } } public void onAdd(DialectRuntimeRegistry registry, CompositeClassLoader rootClassLoader) { this.registry = registry; this.rootClassLoader = rootClassLoader; this.classLoader = new PackageClassLoader( this, this.rootClassLoader ); this.rootClassLoader.addClassLoader( this.classLoader ); } public void onRemove() { this.rootClassLoader.removeClassLoader( this.classLoader ); } public void onBeforeExecute() { if ( isDirty() ) { reload(); } else if ( !this.wireList.isEmpty() ) { try { // wire all remaining resources for ( String resourceName : this.wireList ) { wire( convertResourceToClassName( resourceName ) ); } } catch ( Exception e ) { throw new RuntimeDroolsException( "Unable to wire up JavaDialect", e ); } } this.wireList.clear(); } public DialectRuntimeData clone(DialectRuntimeRegistry registry, CompositeClassLoader rootClassLoader) { DialectRuntimeData cloneOne = new JavaDialectRuntimeData(); cloneOne.merge( registry, this ); cloneOne.onAdd( registry, rootClassLoader ); return cloneOne; } public void merge(DialectRuntimeRegistry registry, DialectRuntimeData newData) { this.registry = registry; JavaDialectRuntimeData newJavaData = (JavaDialectRuntimeData) newData; // First update the binary files // @todo: this probably has issues if you add classes in the incorrect order - functions, rules, invokers. for ( String resourceName : newJavaData.list() ) { write( resourceName, newJavaData.read( resourceName ) ); // // no need to wire, as we already know this is done in a merge // if ( getStore().put( resourceName, // newJavaData.read( resourceName ) ) != null ) { // // we are updating an existing class so reload(); // this.dirty = true; // } // if ( this.dirty == false ) { // // only build up the wireList if we aren't going to reload // this.wireList.add( resourceName ); // } } // if ( this.dirty ) { // // no need to keep wireList if we are going to reload; // this.wireList.clear(); // } // Add invokers putAllInvokers( newJavaData.getInvokers() ); } public boolean isDirty() { return this.dirty; } public void setDirty(boolean dirty) { this.dirty = dirty; } protected Map getStore() { if ( store == null ) { store = new HashMap(); } return store; } public ClassLoader getClassLoader() { return this.classLoader; } public void removeRule(Package pkg, Rule rule) { if ( !(rule instanceof Query) ) { // Query's don't have a consequence, so skip those final String consequenceName = rule.getConsequence().getClass().getName(); // check for compiled code and remove if present. if ( remove( consequenceName ) ) { removeClasses( rule.getLhs() ); // Now remove the rule class - the name is a subset of the consequence name String sufix = StringUtils.ucFirst( rule.getConsequence().getName() ) + "ConsequenceInvoker"; remove( consequenceName.substring( 0, consequenceName.indexOf( sufix ) ) ); } } } public void removeFunction(Package pkg, Function function) { remove( pkg.getName() + "." + StringUtils.ucFirst( function.getName() ) ); } private void removeClasses(final ConditionalElement ce) { if ( ce instanceof GroupElement ) { final GroupElement group = (GroupElement) ce; for ( final Iterator it = group.getChildren().iterator(); it.hasNext(); ) { final Object object = it.next(); if ( object instanceof ConditionalElement ) { removeClasses( (ConditionalElement) object ); } else if ( object instanceof Pattern ) { removeClasses( (Pattern) object ); } } } else if ( ce instanceof EvalCondition ) { remove( ((EvalCondition) ce).getEvalExpression().getClass().getName() ); } } private void removeClasses(final Pattern pattern) { for ( final Iterator it = pattern.getConstraints().iterator(); it.hasNext(); ) { final Object object = it.next(); if ( object instanceof PredicateConstraint ) { remove( ((PredicateConstraint) object).getPredicateExpression().getClass().getName() ); } else if ( object instanceof ReturnValueConstraint ) { remove( ((ReturnValueConstraint) object).getExpression().getClass().getName() ); } } } public byte[] read(final String resourceName) { byte[] bytes = null; if ( !getStore().isEmpty() ) { bytes = (byte[]) getStore().get( resourceName ); } return bytes; } public void write(final String resourceName, final byte[] clazzData) throws RuntimeDroolsException { if ( getStore().put( resourceName, clazzData ) != null ) { this.dirty = true; if ( !this.wireList.isEmpty() ) { this.wireList.clear(); } } else if ( !this.dirty ) { try { if ( this.wireList == Collections.<String> emptyList() ) { this.wireList = new ArrayList<String>(); } this.wireList.add( resourceName ); } catch ( final Exception e ) { e.printStackTrace(); throw new RuntimeDroolsException( e ); } } } public void wire(final String className) throws ClassNotFoundException, InstantiationException, IllegalAccessException { final Object invoker = getInvokers().get( className ); if ( invoker != null ) { wire( className, invoker ); } } public void wire(final String className, final Object invoker) throws ClassNotFoundException, InstantiationException, IllegalAccessException { final Class clazz = this.rootClassLoader.loadClass( className ); if ( clazz != null ) { if ( invoker instanceof Wireable ) { ((Wireable) invoker).wire( clazz.newInstance() ); } // //if ( invoker instanceof ReturnValueRestriction ) { //((ReturnValueRestriction) invoker).setReturnValueExpression( (ReturnValueExpression) clazz.newInstance() ); //} else if ( invoker instanceof PredicateConstraint ) { //((PredicateConstraint) invoker).setPredicateExpression( (PredicateExpression) clazz.newInstance() ); //} else if ( invoker instanceof EvalCondition ) { //((EvalCondition) invoker).setEvalExpression( (EvalExpression) clazz.newInstance() ); //} else if ( invoker instanceof Accumulate ) { //((Accumulate) invoker).setAccumulator( (Accumulator) clazz.newInstance() ); //} else if ( invoker instanceof Rule ) { //((Rule) invoker).setConsequence( (Consequence) clazz.newInstance() ); //} else if ( invoker instanceof JavaAccumulatorFunctionExecutor ) { //((JavaAccumulatorFunctionExecutor) invoker).setExpression( (ReturnValueExpression) clazz.newInstance() ); //} else if ( invoker instanceof DroolsAction ) { //((DroolsAction) invoker).setMetaData( "Action", // clazz.newInstance() ); //} else if ( invoker instanceof ReturnValueConstraintEvaluator ) { //((ReturnValueConstraintEvaluator) invoker).setEvaluator( (ReturnValueEvaluator) clazz.newInstance() ); //} } else { throw new ClassNotFoundException( className ); } } public boolean remove(final String resourceName) throws RuntimeDroolsException { getInvokers().remove( resourceName ); if ( getStore().remove( convertClassToResourcePath( resourceName ) ) != null ) { this.wireList.remove( resourceName ); // we need to make sure the class is removed from the classLoader // reload(); this.dirty = true; return true; } return false; } public String[] list() { String[] names = new String[getStore().size()]; int i = 0; for ( Object object : getStore().keySet() ) { names[i++] = (String) object; } return names; } /** * This class drops the classLoader and reloads it. During this process it must re-wire all the invokeables. * @throws RuntimeDroolsException */ public void reload() throws RuntimeDroolsException { // drops the classLoader and adds a new one this.rootClassLoader.removeClassLoader( this.classLoader ); this.classLoader = new PackageClassLoader( this, this.rootClassLoader ); this.rootClassLoader.addClassLoader( this.classLoader ); // Wire up invokers try { for ( final Object object : getInvokers().entrySet() ) { Entry entry = (Entry) object; wire( (String) entry.getKey(), entry.getValue() ); } } catch ( final ClassNotFoundException e ) { throw new RuntimeDroolsException( e ); } catch ( final InstantiationError e ) { throw new RuntimeDroolsException( e ); } catch ( final IllegalAccessException e ) { throw new RuntimeDroolsException( e ); } catch ( final InstantiationException e ) { throw new RuntimeDroolsException( e ); } this.dirty = false; } public void clear() { getStore().clear(); getInvokers().clear(); reload(); } public String toString() { return this.getClass().getName() + getStore().toString(); } public void putInvoker(final String className, final Object invoker) { getInvokers().put( className, invoker ); } public void putAllInvokers(final Map invokers) { getInvokers().putAll( invokers ); } public Map getInvokers() { if ( this.invokerLookups == null ) { this.invokerLookups = new HashMap(); } return this.invokerLookups; } public void removeInvoker(final String className) { getInvokers().remove( className ); } /** * This is an Internal Drools Class * */ public static class PackageClassLoader extends ClassLoader implements FastClassLoader { private JavaDialectRuntimeData store; CompositeClassLoader rootClassLoader; public PackageClassLoader(JavaDialectRuntimeData store, CompositeClassLoader rootClassLoader) { super( rootClassLoader ); this.rootClassLoader = rootClassLoader; this.store = store; } public Class< ? > loadClass(final String name, final boolean resolve) throws ClassNotFoundException { Class< ? > cls = fastFindClass( name ); if ( cls == null ) { final CompositeClassLoader parent = ( CompositeClassLoader ) getParent(); cls = parent.loadClass( name, resolve, this ); } if ( cls == null ) { throw new ClassNotFoundException("Unable to load class: " + name); } return cls; } public Class< ? > fastFindClass(final String name) { Class< ? > cls = findLoadedClass( name ); if ( cls == null ) { final byte[] clazzBytes = this.store.read( convertClassToResourcePath( name ) ); if ( clazzBytes != null ) { String pkgName = name.substring( 0, name.lastIndexOf( '.' ) ); if ( getPackage( pkgName ) == null ) { definePackage( pkgName, "", "", "", "", "", "", null ); } cls = defineClass( name, clazzBytes, 0, clazzBytes.length, PROTECTION_DOMAIN ); } if ( cls != null ) { resolveClass( cls ); } } return cls; } public InputStream getResourceAsStream(final String name) { final byte[] clsBytes = this.store.read( name ); if ( clsBytes != null ) { return new ByteArrayInputStream( clsBytes ); } return null; } public URL getResource(String name) { return null; } public Enumeration<URL> getResources(String name) throws IOException { return null; } } /** * Please do not use - internal * org/my/Class.xxx -> org.my.Class */ public static String convertResourceToClassName(final String pResourceName) { return stripExtension( pResourceName ).replace( '/', '.' ); } /** * Please do not use - internal * org.my.Class -> org/my/Class.class */ public static String convertClassToResourcePath(final String pName) { return pName.replace( '.', '/' ) + ".class"; } /** * Please do not use - internal * org/my/Class.xxx -> org/my/Class */ public static String stripExtension(final String pResourceName) { final int i = pResourceName.lastIndexOf( '.' ); final String withoutExtension = pResourceName.substring( 0, i ); return withoutExtension; } }