/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.activemq.artemis.tests.unit.util; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.File; import java.io.FileInputStream; import java.io.FileOutputStream; import java.io.IOException; import java.io.ObjectOutputStream; import java.io.Serializable; import java.lang.reflect.InvocationHandler; import java.lang.reflect.Method; import java.lang.reflect.Proxy; import java.net.URL; import java.net.URLClassLoader; import java.security.CodeSource; import java.security.ProtectionDomain; import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; import java.util.StringTokenizer; import org.apache.activemq.artemis.tests.unit.util.deserialization.pkg1.EnclosingClass; import org.apache.activemq.artemis.tests.unit.util.deserialization.pkg1.TestClass1; import org.apache.activemq.artemis.tests.unit.util.deserialization.pkg1.TestClass2; import org.apache.activemq.artemis.tests.util.ActiveMQTestBase; import org.apache.activemq.artemis.utils.ObjectInputStreamWithClassLoader; import org.junit.Assert; import org.junit.Test; public class ObjectInputStreamWithClassLoaderTest extends ActiveMQTestBase { // Constants ----------------------------------------------------- // Attributes ---------------------------------------------------- // Static -------------------------------------------------------- public static ClassLoader newClassLoader(final Class... userClasses) throws Exception { Set<URL> userClassUrls = new HashSet<>(); for (Class anyUserClass : userClasses) { ProtectionDomain protectionDomain = anyUserClass.getProtectionDomain(); CodeSource codeSource = protectionDomain.getCodeSource(); URL classLocation = codeSource.getLocation(); userClassUrls.add(classLocation); } StringTokenizer tokenString = new StringTokenizer(System.getProperty("java.class.path"), File.pathSeparator); String pathIgnore = System.getProperty("java.home"); if (pathIgnore == null) { pathIgnore = userClassUrls.iterator().next().toString(); } List<URL> urls = new ArrayList<>(); while (tokenString.hasMoreElements()) { String value = tokenString.nextToken(); URL itemLocation = new File(value).toURI().toURL(); if (!userClassUrls.contains(itemLocation) && itemLocation.toString().indexOf(pathIgnore) >= 0) { urls.add(itemLocation); } } URL[] urlArray = urls.toArray(new URL[urls.size()]); ClassLoader masterClassLoader = URLClassLoader.newInstance(urlArray, null); ClassLoader appClassLoader = URLClassLoader.newInstance(userClassUrls.toArray(new URL[0]), masterClassLoader); return appClassLoader; } // Constructors -------------------------------------------------- // Public -------------------------------------------------------- @Test public void testClassLoaderIsolation() throws Exception { ClassLoader originalClassLoader = Thread.currentThread().getContextClassLoader(); try { AnObject obj = new AnObjectImpl(); byte[] bytes = ObjectInputStreamWithClassLoaderTest.toBytes(obj); //Class.isAnonymousClass() call used in ObjectInputStreamWithClassLoader //need to access the enclosing class and its parent class of the obj //i.e. ActiveMQTestBase and Assert. ClassLoader testClassLoader = ObjectInputStreamWithClassLoaderTest.newClassLoader(obj.getClass(), ActiveMQTestBase.class, Assert.class); Thread.currentThread().setContextClassLoader(testClassLoader); ByteArrayInputStream bais = new ByteArrayInputStream(bytes); ObjectInputStreamWithClassLoader ois = new ObjectInputStreamWithClassLoader(bais); Object deserializedObj = ois.readObject(); Assert.assertNotSame(obj, deserializedObj); Assert.assertNotSame(obj.getClass(), deserializedObj.getClass()); Assert.assertNotSame(obj.getClass().getClassLoader(), deserializedObj.getClass().getClassLoader()); Assert.assertSame(testClassLoader, deserializedObj.getClass().getClassLoader()); } finally { Thread.currentThread().setContextClassLoader(originalClassLoader); } } @Test public void testClassLoaderIsolationWithProxy() throws Exception { ClassLoader originalClassLoader = Thread.currentThread().getContextClassLoader(); try { AnObject originalProxy = (AnObject) Proxy.newProxyInstance(AnObject.class.getClassLoader(), new Class[]{AnObject.class}, new AnObjectInvocationHandler()); originalProxy.setMyInt(100); byte[] bytes = ObjectInputStreamWithClassLoaderTest.toBytes(originalProxy); ClassLoader testClassLoader = ObjectInputStreamWithClassLoaderTest.newClassLoader(this.getClass(), ActiveMQTestBase.class, Assert.class); Thread.currentThread().setContextClassLoader(testClassLoader); ByteArrayInputStream bais = new ByteArrayInputStream(bytes); ObjectInputStreamWithClassLoader ois = new ObjectInputStreamWithClassLoader(bais); Runnable toRun = (Runnable) testClassLoader.loadClass(ProxyReader.class.getName()).newInstance(); toRun.getClass().getField("ois").set(toRun, ois); toRun.getClass().getField("testClassLoader").set(toRun, testClassLoader); toRun.getClass().getField("originalProxy").set(toRun, originalProxy); toRun.run(); } finally { Thread.currentThread().setContextClassLoader(originalClassLoader); } } @Test public void testWhiteBlackList() throws Exception { File serailizeFile = new File(temporaryFolder.getRoot(), "testclass.bin"); ObjectOutputStream outputStream = new ObjectOutputStream(new FileOutputStream(serailizeFile)); try { outputStream.writeObject(new TestClass1()); outputStream.flush(); } finally { outputStream.close(); } //default assertNull(readSerializedObject(null, null, serailizeFile)); //white list String whiteList = "org.apache.activemq.artemis.tests.unit.util.deserialization"; assertNull(readSerializedObject(whiteList, null, serailizeFile)); whiteList = "org.apache.activemq.artemis.tests.unit.util.deserialization.pkg1"; assertNull(readSerializedObject(whiteList, null, serailizeFile)); whiteList = "some.other.package"; Exception result = readSerializedObject(whiteList, null, serailizeFile); assertTrue(result instanceof ClassNotFoundException); //blacklist String blackList = "org.apache.activemq.artemis.tests.unit.util"; result = readSerializedObject(null, blackList, serailizeFile); assertTrue(result instanceof ClassNotFoundException); blackList = "org.apache.activemq.artemis.tests.unit.util.deserialization.pkg1"; result = readSerializedObject(null, blackList, serailizeFile); assertTrue(result instanceof ClassNotFoundException); blackList = "org.apache.activemq.artemis.tests.unit.util.deserialization.pkg2"; result = readSerializedObject(null, blackList, serailizeFile); assertNull(result); blackList = "some.other.package"; whiteList = "some.other.package1"; result = readSerializedObject(whiteList, blackList, serailizeFile); assertTrue(result instanceof ClassNotFoundException); //blacklist priority blackList = "org.apache.activemq.artemis.tests.unit.util.deserialization.pkg1, some.other.package"; whiteList = "org.apache.activemq.artemis.tests.unit.util.deserialization.pkg1"; result = readSerializedObject(whiteList, blackList, serailizeFile); assertTrue(result instanceof ClassNotFoundException); blackList = "org.apache.activemq.artemis.tests.unit, some.other.package"; whiteList = "org.apache.activemq.artemis.tests.unit.util.deserialization.pkg1"; result = readSerializedObject(whiteList, blackList, serailizeFile); assertTrue(result instanceof ClassNotFoundException); blackList = "org.apache.activemq.artemis.tests.unit.util.deserialization.pkg1.pkg2, some.other.package"; whiteList = "org.apache.activemq.artemis.tests.unit.util.deserialization.pkg1"; result = readSerializedObject(whiteList, blackList, serailizeFile); assertNull(result); blackList = "some.other.package, org.apache.activemq.artemis.tests.unit.util.deserialization.pkg2"; whiteList = "org.apache.activemq.artemis.tests.unit.util.deserialization.pkg1"; result = readSerializedObject(whiteList, blackList, serailizeFile); assertNull(result); //wildcard blackList = "*"; whiteList = "org.apache.activemq.artemis.tests.unit.util.deserialization.pkg1"; result = readSerializedObject(whiteList, blackList, serailizeFile); assertTrue(result instanceof ClassNotFoundException); blackList = "*"; whiteList = "*"; result = readSerializedObject(whiteList, blackList, serailizeFile); assertTrue(result instanceof ClassNotFoundException); result = readSerializedObject(whiteList, null, serailizeFile); assertNull(result); } @Test public void testWhiteBlackListAgainstArrayObject() throws Exception { File serailizeFile = new File(temporaryFolder.getRoot(), "testclass.bin"); TestClass1[] sourceObject = new TestClass1[]{new TestClass1()}; ObjectOutputStream outputStream = new ObjectOutputStream(new FileOutputStream(serailizeFile)); try { outputStream.writeObject(sourceObject); outputStream.flush(); } finally { outputStream.close(); } //default ok String blackList = null; String whiteList = null; Object result = readSerializedObject(whiteList, blackList, serailizeFile); assertNull(result); //now blacklist TestClass1 blackList = "org.apache.activemq.artemis.tests.unit.util.deserialization.pkg1.TestClass1"; whiteList = null; result = readSerializedObject(whiteList, blackList, serailizeFile); assertTrue(result instanceof ClassNotFoundException); //now whitelist TestClass1, it should pass. blackList = null; whiteList = "org.apache.activemq.artemis.tests.unit.util.deserialization.pkg1.TestClass1"; result = readSerializedObject(whiteList, blackList, serailizeFile); assertNull(result); } @Test public void testWhiteBlackListAgainstListObject() throws Exception { File serailizeFile = new File(temporaryFolder.getRoot(), "testclass.bin"); List<TestClass1> sourceObject = new ArrayList<>(); sourceObject.add(new TestClass1()); ObjectOutputStream outputStream = new ObjectOutputStream(new FileOutputStream(serailizeFile)); try { outputStream.writeObject(sourceObject); outputStream.flush(); } finally { outputStream.close(); } //default ok String blackList = null; String whiteList = null; Object result = readSerializedObject(whiteList, blackList, serailizeFile); assertNull(result); //now blacklist TestClass1 blackList = "org.apache.activemq.artemis.tests.unit.util.deserialization.pkg1.TestClass1"; whiteList = null; result = readSerializedObject(whiteList, blackList, serailizeFile); assertTrue(result instanceof ClassNotFoundException); //now whitelist TestClass1, should fail because the List type is not allowed blackList = null; whiteList = "org.apache.activemq.artemis.tests.unit.util.deserialization.pkg1.TestClass1"; result = readSerializedObject(whiteList, blackList, serailizeFile); assertTrue(result instanceof ClassNotFoundException); //now add List to white list, it should pass blackList = null; whiteList = "org.apache.activemq.artemis.tests.unit.util.deserialization.pkg1.TestClass1," + "java.util.ArrayList"; result = readSerializedObject(whiteList, blackList, serailizeFile); assertNull(result); } @Test public void testWhiteBlackListAgainstListMapObject() throws Exception { File serailizeFile = new File(temporaryFolder.getRoot(), "testclass.bin"); Map<TestClass1, TestClass2> sourceObject = new HashMap<>(); sourceObject.put(new TestClass1(), new TestClass2()); ObjectOutputStream outputStream = new ObjectOutputStream(new FileOutputStream(serailizeFile)); try { outputStream.writeObject(sourceObject); outputStream.flush(); } finally { outputStream.close(); } String blackList = null; String whiteList = null; Object result = readSerializedObject(whiteList, blackList, serailizeFile); assertNull(result); //now blacklist the key blackList = "org.apache.activemq.artemis.tests.unit.util.deserialization.pkg1.TestClass1"; whiteList = null; result = readSerializedObject(whiteList, blackList, serailizeFile); assertTrue(result instanceof ClassNotFoundException); //now blacklist the value blackList = "org.apache.activemq.artemis.tests.unit.util.deserialization.pkg1.TestClass2"; whiteList = null; result = readSerializedObject(whiteList, blackList, serailizeFile); assertTrue(result instanceof ClassNotFoundException); //now white list the key, should fail too because value is forbidden blackList = null; whiteList = "org.apache.activemq.artemis.tests.unit.util.deserialization.pkg1.TestClass1"; result = readSerializedObject(whiteList, blackList, serailizeFile); assertTrue(result instanceof ClassNotFoundException); //now white list the value, should fail too because the key is forbidden blackList = null; whiteList = "org.apache.activemq.artemis.tests.unit.util.deserialization.pkg1.TestClass2"; result = readSerializedObject(whiteList, blackList, serailizeFile); assertTrue(result instanceof ClassNotFoundException); //both key and value are in the whitelist, it should fail because HashMap not permitted blackList = null; whiteList = "org.apache.activemq.artemis.tests.unit.util.deserialization.pkg1.TestClass1," + "org.apache.activemq.artemis.tests.unit.util.deserialization.pkg1.TestClass2"; result = readSerializedObject(whiteList, blackList, serailizeFile); assertTrue(result instanceof ClassNotFoundException); //now add HashMap, test should pass. blackList = null; whiteList = "org.apache.activemq.artemis.tests.unit.util.deserialization.pkg1.TestClass1," + "org.apache.activemq.artemis.tests.unit.util.deserialization.pkg1.TestClass2," + "java.util.HashMap"; result = readSerializedObject(whiteList, blackList, serailizeFile); assertNull(result); } @Test public void testWhiteBlackListAnonymousObject() throws Exception { File serailizeFile = new File(temporaryFolder.getRoot(), "testclass.bin"); ObjectOutputStream outputStream = new ObjectOutputStream(new FileOutputStream(serailizeFile)); try { Serializable object = EnclosingClass.anonymousObject; assertTrue(object.getClass().isAnonymousClass()); outputStream.writeObject(object); outputStream.flush(); } finally { outputStream.close(); } //default String blackList = null; String whiteList = null; assertNull(readSerializedObject(whiteList, blackList, serailizeFile)); //forbidden by specifying the enclosing class blackList = "org.apache.activemq.artemis.tests.unit.util.deserialization.pkg1.EnclosingClass"; Object result = readSerializedObject(whiteList, blackList, serailizeFile); assertTrue(result instanceof ClassNotFoundException); //do it in whiteList blackList = null; whiteList = "org.apache.activemq.artemis.tests.unit.util.deserialization.pkg1.EnclosingClass"; result = readSerializedObject(whiteList, blackList, serailizeFile); assertNull(result); } @Test public void testWhiteBlackListLocalObject() throws Exception { File serailizeFile = new File(temporaryFolder.getRoot(), "testclass.bin"); ObjectOutputStream outputStream = new ObjectOutputStream(new FileOutputStream(serailizeFile)); try { Object object = EnclosingClass.getLocalObject(); assertTrue(object.getClass().isLocalClass()); outputStream.writeObject(object); outputStream.flush(); } finally { outputStream.close(); } //default String blackList = null; String whiteList = null; assertNull(readSerializedObject(whiteList, blackList, serailizeFile)); //forbidden by specifying the enclosing class blackList = "org.apache.activemq.artemis.tests.unit.util.deserialization.pkg1.EnclosingClass"; Object result = readSerializedObject(whiteList, blackList, serailizeFile); assertTrue(result instanceof ClassNotFoundException); //do it in whiteList blackList = null; whiteList = "org.apache.activemq.artemis.tests.unit.util.deserialization.pkg1.EnclosingClass"; result = readSerializedObject(whiteList, blackList, serailizeFile); assertNull(result); } @Test public void testWhiteBlackListSystemProperty() throws Exception { File serailizeFile = new File(temporaryFolder.getRoot(), "testclass.bin"); ObjectOutputStream outputStream = new ObjectOutputStream(new FileOutputStream(serailizeFile)); try { outputStream.writeObject(new TestClass1()); outputStream.flush(); } finally { outputStream.close(); } System.setProperty(ObjectInputStreamWithClassLoader.BLACKLIST_PROPERTY, "system.defined.black.list"); System.setProperty(ObjectInputStreamWithClassLoader.WHITELIST_PROPERTY, "system.defined.white.list"); try { ObjectInputStreamWithClassLoader ois = new ObjectInputStreamWithClassLoader(new FileInputStream(serailizeFile)); String bList = ois.getBlackList(); String wList = ois.getWhiteList(); assertEquals("wrong black list: " + bList, "system.defined.black.list", bList); assertEquals("wrong white list: " + wList, "system.defined.white.list", wList); ois.close(); } finally { System.clearProperty(ObjectInputStreamWithClassLoader.BLACKLIST_PROPERTY); System.clearProperty(ObjectInputStreamWithClassLoader.WHITELIST_PROPERTY); } } private Exception readSerializedObject(String whiteList, String blackList, File serailizeFile) { Exception result = null; ObjectInputStreamWithClassLoader ois = null; try { ois = new ObjectInputStreamWithClassLoader(new FileInputStream(serailizeFile)); ois.setWhiteList(whiteList); ois.setBlackList(blackList); ois.readObject(); } catch (Exception e) { result = e; } finally { try { ois.close(); } catch (IOException e) { result = e; } } return result; } // Package protected --------------------------------------------- // Protected ----------------------------------------------------- // Private ------------------------------------------------------- public static class ProxyReader implements Runnable { public java.io.ObjectInputStream ois; public Object originalProxy; public ClassLoader testClassLoader; // We don't have access to the junit framework on the classloader where this is running void myAssertNotSame(Object obj, Object obj2) { if (obj == obj2) { throw new RuntimeException("Expected to be different objects"); } } // We don't have access to the junit framework on the classloader where this is running void myAssertSame(Object obj, Object obj2) { if (obj != obj2) { throw new RuntimeException("Expected to be the same objects"); } } @Override public void run() { try { Object deserializedObj = ois.readObject(); System.out.println("Deserialized Object " + deserializedObj); myAssertNotSame(originalProxy, deserializedObj); myAssertNotSame(originalProxy.getClass(), deserializedObj.getClass()); myAssertNotSame(originalProxy.getClass().getClassLoader(), deserializedObj.getClass().getClassLoader()); myAssertSame(testClassLoader, deserializedObj.getClass().getClassLoader()); AnObject myInterface = (AnObject) deserializedObj; if (myInterface.getMyInt() != 200) { throw new RuntimeException("invalid result"); } } catch (ClassNotFoundException e) { throw new RuntimeException(e.getMessage(), e); } catch (IOException e) { throw new RuntimeException(e.getMessage(), e); } } } private static byte[] toBytes(final Object obj) throws IOException { Assert.assertTrue(obj instanceof Serializable); ByteArrayOutputStream baos = new ByteArrayOutputStream(); ObjectOutputStream oos = new ObjectOutputStream(baos); oos.writeObject(obj); oos.flush(); return baos.toByteArray(); } // Inner classes ------------------------------------------------- private interface AnObject extends Serializable { int getMyInt(); void setMyInt(int value); long getMyLong(); void setMyLong(long value); } private static class AnObjectImpl implements AnObject { private static final long serialVersionUID = -5172742084489525256L; int myInt = 0; long myLong = 0L; @Override public int getMyInt() { return myInt; } @Override public void setMyInt(int value) { this.myInt = value; } @Override public long getMyLong() { return myLong; } @Override public void setMyLong(long value) { this.myLong = value; } } private static class AnObjectInvocationHandler implements InvocationHandler, Serializable { private static final long serialVersionUID = -3875973764178767452L; private final AnObject anObject = new AnObjectImpl(); @Override public Object invoke(Object proxy, Method method, Object[] args) throws Throwable { Object obj = method.invoke(anObject, args); if (obj instanceof Integer) { return ((Integer) obj).intValue() * 2; } else { return obj; } } } }